From 326f925361ff2c1bd85ea805aa63c836912bcf9f Mon Sep 17 00:00:00 2001
From: Hangxing Wei
Date: Mon, 22 Apr 2024 13:27:46 +0800
Subject: [PATCH] Feat/modelhub (#1)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* feat: translations (#3176)
* fix: prompt editor variable picker (#3177)
* Fix: features of agent-chat (#3178)
* version to 0.6.0-fix1 (#3179)
* fix keyword index error when storage source is S3 (#3182)
* Update README.md to include workflows (#3180)
* Compatible with unique index conflicts (#3183)
* fix: sometimes chosen old selected knowledge may overwrite the new knowledge (#3199)
* Fix: remove unavailable return_preamble parameter in cohere (#3201) Signed-off-by: Jat
* Fix/code transform result (#3203)
* fix(code_executor): surrogates not allowed error in jinja2 template (#3191)
* fix: node connect self (#3194)
* Update README.md (#3206)
* fix economy index search in workflow (#3205)
* fix: index number in api/README (#3214)
* Update README.md (#3212)
* fix detached instance error in keyword index create thread and fix question classifier node out of index error (#3219)
* fix: incomplete response (#3215)
* fix: latest image tag not push in GitHub action (#3213)
* fix: vision config doesn't enabled in llm (#3225)
* fixed the issue of missing cleanup function in the AudioBtn component (#3133)
* fix: image text when retrieve chat histories (#3220)
* feat: moonshot function call (#3227)
* feat: support setting database used in Milvus (#3003)
* fix milvus database name parameter missed (#3229)
* fix: file not uploaded caused api error (#3228)
* update link (#3226)
* fix: skip Celery warning by setting broker_connection_retry_on_startup config (#3188)
* fix: workflow run edge status (#3236)
* fix: empty conversation list of explore chatbot (#3235)
* Fix: picture of workflow (#3241)
* feat: prompt-editor support undo (#3242)
* fix: number type in app would render as select type in webapp (#3244)
* fix: token is not logging of question classifier node (#3249)
* chore: remove langchain in tools (#3247)
* make sure validation flow works for all model providers in bedrock (#3250)
* feat: remove unregistered-llm-in-debug (#3251)
* version to 0.6.1 (#3253)
* fix: agent chat multiple model debug (#3258)
* feat: gpt-4-turbo (#3263)
* fix: image was sent to an unsupported LLM when sending second message (#3268)
* feat: vision parameter support of OpenAI Compatible API (#3272)
* fix: var assigner input node can not find caused error (#3274)
* fix: variable-assigner node connect (#3288)
* Feat/Agent-Image-Processing (#3293) Co-authored-by: Joel
* chore: address security alerts on braces escape and KaTeX (#3301)
* chore(deps): bump katex from 0.16.8 to 0.16.10 in /web (#3307) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
* Update README.md (#3281)
* Remove langchain dataset retrieval agent logic (#3311)
* add german translations (#3322)
* feat: add missing workflow i18n keys (#3309) Co-authored-by: lbm21 <313338264@qq.com>
* feat:add ‘name’ field return (#3152)
* improvement: speed up dependency installation in docker image rebuilds by mounting cache layer (#3218)
* feat: support gpt-4-turbo-2024-04-09 model (#3300)
* feat: Add Cohere Command R / R+ model support (#3333)
* fix dataset retrieval in dataset mode (#3334)
* chore:bump pypdfium2 from 4.16.0 to 4.17.0 (#3310)
* feat(llm/models): add gemini-1.5-pro (#2925)
* feat: make input size bigger in start (#3340)
* Doc/update readme (#3344)
* fix: leave progress page still call indexing-status api (#3345)
* feat: update aws bedrock (#3326) Co-authored-by: chenhe
* fix/moonshot-function-call (#3339)
* fix issue: user’s keywords do not affect when add segment (#3349)
* add segment with keyword issue (#3351) Co-authored-by: StyleZhang
* Fix issue : don't delete DatasetProcessRule, DatasetQuery and AppDatasetJoin when delete dataset with no document (#3354)
* fix: remove middle editor may cause render placement error (#3356)
* Added a note on the front-end docker build: use taobao source to accelerate the installation of front-end dependency packages to achieve the purpose of quickly building containers (#3358) Co-authored-by: lbm21 <313338264@qq.com> Co-authored-by: akou
* fix: var name too long would break ui in var assigner and end nodes (#3361)
* Refactor/react agent (#3355)
* Fix/Bing Search url endpoint cannot be customized (#3366)
* fix: image token calc of OpenAI Compatible API (#3368)
* Update README.md (#3371)
* update workflow intro mp4 codec (#3372)
* fix: cohere tool call does not support single tool (#3373)
* version to 0.6.2 (#3375)
* fix: variable pool mapping variable mixed up (#3378)
* version to 0.6.2-fix1 (#3380)
* fix: yarn install extract package err when using GitHub Cache in amd6… (#3383)
* feat: Add support for embed file with AWS Bedrock Titan Model (#3377) Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
* fix: remove - in dataset retriever tool name (#3381)
* feat:api Add support for extracting EPUB files in ExtractProcessor (#3254) Co-authored-by: crazywoola <427733928@qq.com>
* feat: show citation info in run history (#3399)
* Feat: Invitation link automatically completes domain name (#3393) Co-authored-by: huangbaichao
* Integrated SearXNG search as built-in tool (#3363) Co-authored-by: crazywoola <427733928@qq.com>
* fix: [azure_openai] Error: 'NoneType' object has no attribute 'content' (#3389)
* Update providers preview (#3403)
* add xls file support (#3321)
* Update README.md (#3405)
* Fix/workflow tool incorrect parameter configurations (#3402) Co-authored-by: Joel
* chore: replace all set interval (#3411)
* feat: Deprecate datetime.utcnow() in favor of datetime.now(timezone.utc).replace(tzinfo=None) for better timezone handling (#3408) (#3416)
* chore: remove Langchain tools import (#3407)
* feat: gemini pro function call (#3406)
* fix: shared text-generation stream (#3419)
* fix/dataset-retriever-tool-parameter-redundancy (#3418)
* Feat/api tool custom timeout (#3420)
* fix: test env key missing or wrong (#3430)
* Doc/update readme (#3433)
* Update README_CN.md (#3434)
* Update README_CN.md (#3435)
* feat: add workflow editor shortcuts (#3382) (#3390)
* FEAT: cohere rerank 3 model added (#3431)
* chore: remove the COPY instruction in .devcontainer/Dockerfile (#3409)
* fix typo: Changlog -> Changelog (#3442)
* fix: node shortcuts active in input fields (#3438)
* Add nvidia codegemma 7b (#3437)
* Update yaml and py file in Tavily Tool (#3450)
* feat: Added the mirror of Aliyun's Linux apk installation package and updated the deprecated taobao npm mirror address to npmmirror (#3459)
* Revert "Update yaml and py file in Tavily Tool" (#3464)
* feat: jina reader (#3468)
* feat: support configuring openai compatible stream tool call (#3467)
* feat: optimize the efficiency of generating chatbot conversation name (#3472)
* feat: remove langchain from output parsers (#3473)
* chore: separate Python dependencies for development (#3198)
* chore: add sandbox permission tooltip (#3477)
* fix: prompt template issue (#3449)
* feat: support relyt vector database (#3367) Co-authored-by: jingsi
* Update README.md (#3478)
* nvidia-label-update (#3482)
* fix: in conversation log click op button would cause close drawer (#3483)
* fix: workflow auto layout nodes offset & delete node shortcuts (#3484)
* fix: workflow edge curvature (#3488)
* fix: stringify object while exporting batch result to csv (#3481)
* question classifier prompt optimization (#3479)
* feat: refactor tongyi models (#3496)
* fix: bump twilio to 9.0.4 skipping yanked versions (#3500)
* test: install ffmpeg for pytests (#3499)
* feat: support var auto rename in prompt editor (#3510)
* fix: add message caused problem after simple chat convert to workflow (#3511)
* fix: the object field is empty string in some openAI api compatible model (#3506)
* Add support for AWS Bedrock Cohere embedding (#3444)
* fix: add completion mode object check (#3515)
* get config default for sandbox (#3508) Co-authored-by: miendinh
* chore: improve reference variable picker user experience (#3517)
* fix: array[string] context in llm node invalid (#3518)
* version to 0.6.3 (#3519)
* fix the return with wrong datatype of segment (#3525)
* fix: the hover style of the card-item operation button container (#3520)
* chore: lint .env file templates (#3507)
* add support for swagger object type (#3426) Co-authored-by: lipeikui
* /fix register
* /add modelhub
* localhost
* /fix req bugs
* modelhub add rerank

---------

Signed-off-by: Jat
Signed-off-by: dependabot[bot]
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: zxhlyh
Co-authored-by: KVOJJJin
Co-authored-by: takatost
Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com>
Co-authored-by: Chenhe Gu
Co-authored-by: Joel
Co-authored-by: Jat
Co-authored-by: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
Co-authored-by: Eric Wang
Co-authored-by: Bowen Liang
Co-authored-by: legao <837937787@qq.com>
Co-authored-by: Leo Q
Co-authored-by: minakokojima
Co-authored-by: Nite Knite
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Lao
Co-authored-by: lbm21 <313338264@qq.com>
Co-authored-by: 呆萌闷油瓶 <253605712@qq.com>
Co-authored-by: Kenny
Co-authored-by: akou
Co-authored-by: longzhihun <38651850@qq.com>
Co-authored-by: LiuVaayne <10231735+vaayne@users.noreply.github.com>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: Moonlit
Co-authored-by: huangbaichao
Co-authored-by: junytang
Co-authored-by: saga.rey
Co-authored-by: chenxu9741 <1309095142@qq.com>
Co-authored-by: LIU HONGWEI <1327374483@qq.com>
Co-authored-by: Pascal M <11357019+perzeuss@users.noreply.github.com>
Co-authored-by: Yash Parmar <82636823+Yash-1511@users.noreply.github.com>
Co-authored-by: Bodhi <3882561+BodhiHu@users.noreply.github.com>
Co-authored-by: Selene29
Co-authored-by: Josh Feng
Co-authored-by: Richards Tu <142148415+richards199999@users.noreply.github.com>
Co-authored-by: YidaHu
Co-authored-by: Jingpan Xiong <71321890+klaus-xiong@users.noreply.github.com>
Co-authored-by: jingsi
Co-authored-by: Joshua <138381132+joshua20231026@users.noreply.github.com>
Co-authored-by: sino
Co-authored-by: liuzhenghua <1090179900@qq.com>
Co-authored-by: kerlion <40377268+kerlion@users.noreply.github.com>
Co-authored-by: miendinh <22139872+miendinh@users.noreply.github.com>
Co-authored-by: miendinh
Co-authored-by: buu
Co-authored-by: LeePui <444561897@qq.com>
Co-authored-by: lipeikui
---
 .devcontainer/Dockerfile | 3
- .github/workflows/api-tests.yml | 11 +- .github/workflows/build-push.yml | 2 +- .github/workflows/style.yml | 5 +- CONTRIBUTING.md | 2 +- CONTRIBUTING_CN.md | 2 +- README.md | 295 +++--- README_CN.md | 232 +++-- README_ES.md | 284 ++++-- README_FR.md | 297 ++++-- README_JA.md | 303 ++++-- README_KL.md | 287 ++++-- api/.env.example | 13 +- api/Dockerfile | 3 +- api/README.md | 14 +- api/app.py | 14 +- api/commands.py | 405 +++++--- api/config.py | 19 +- api/controllers/console/app/conversation.py | 4 +- api/controllers/console/auth/activate.py | 4 +- api/controllers/console/auth/oauth.py | 4 +- .../console/datasets/data_source.py | 4 +- .../console/datasets/datasets_document.py | 18 +- .../console/datasets/datasets_segments.py | 4 +- api/controllers/console/explore/completion.py | 6 +- .../console/explore/conversation.py | 3 +- .../console/explore/installed_app.py | 4 +- api/controllers/console/workspace/account.py | 6 +- api/controllers/service_api/app/app.py | 17 +- .../service_api/app/conversation.py | 9 +- api/controllers/service_api/wraps.py | 4 +- api/controllers/web/conversation.py | 4 +- api/core/agent/base_agent_runner.py | 86 +- api/core/agent/cot_agent_runner.py | 630 ++++-------- api/core/agent/cot_chat_agent_runner.py | 71 ++ api/core/agent/cot_completion_agent_runner.py | 69 ++ api/core/agent/entities.py | 17 + api/core/agent/fc_agent_runner.py | 145 +-- .../agent/output_parser/cot_output_parser.py | 183 ++++ .../app/apps/advanced_chat/app_generator.py | 3 + .../advanced_chat/generate_task_pipeline.py | 10 + api/core/app/apps/agent_chat/app_generator.py | 3 + api/core/app/apps/agent_chat/app_runner.py | 81 +- api/core/app/apps/chat/app_generator.py | 3 + api/core/app/apps/chat/app_runner.py | 2 + api/core/app/apps/completion/app_generator.py | 3 + api/core/app/apps/completion/app_runner.py | 2 + api/core/app/apps/workflow/app_generator.py | 3 + .../easy_ui_based_generate_task_pipeline.py | 12 + .../app/task_pipeline/message_cycle_manage.py | 59 +- .../task_pipeline/workflow_cycle_manage.py | 10 +- .../agent_tool_callback_handler.py | 28 +- api/core/docstore/dataset_docstore.py | 2 +- api/core/embedding/cached_embedding.py | 28 +- api/core/entities/message_entities.py | 108 +-- api/core/entities/provider_configuration.py | 4 +- .../helper/code_executor/code_executor.py | 7 +- .../code_executor/javascript_transformer.py | 6 +- .../helper/code_executor/jina2_transformer.py | 6 +- .../code_executor/python_transformer.py | 2 +- api/core/indexing_runner.py | 81 +- api/core/llm_generator/llm_generator.py | 3 +- .../llm_generator/output_parser/errors.py | 2 + .../output_parser/rule_config_generator.py | 5 +- .../suggested_questions_after_answer.py | 4 +- api/core/memory/token_buffer_memory.py | 15 +- .../model_providers/_position.yaml | 3 +- .../azure_openai/azure_openai.yaml | 6 + .../model_providers/azure_openai/llm/llm.py | 8 +- .../model_providers/bedrock/bedrock.yaml | 3 +- .../llm/amazon.titan-text-express-v1.yaml | 2 - .../llm/amazon.titan-text-lite-v1.yaml | 2 - .../bedrock/llm/anthropic.claude-v1.yaml | 1 + .../llm/cohere.command-light-text-v14.yaml | 2 +- .../bedrock/llm/cohere.command-text-v14.yaml | 6 +- .../model_providers/bedrock/llm/llm.py | 36 +- .../bedrock/text_embedding}/__init__.py | 0 .../bedrock/text_embedding/_position.yaml | 3 + .../amazon.titan-embed-text-v1.yaml | 8 + .../cohere.embed-english-v3.yaml | 8 + .../cohere.embed-multilingual-v3.yaml | 8 + .../bedrock/text_embedding/text_embedding.py | 234 +++++ .../model_providers/cohere/llm/_position.yaml | 2 + 
.../cohere/llm/command-chat.yaml | 2 +- .../cohere/llm/command-light-chat.yaml | 2 +- .../llm/command-light-nightly-chat.yaml | 2 +- .../cohere/llm/command-light-nightly.yaml | 2 +- .../cohere/llm/command-light.yaml | 2 +- .../cohere/llm/command-nightly-chat.yaml | 2 +- .../cohere/llm/command-nightly.yaml | 2 +- .../cohere/llm/command-r-plus.yaml | 45 + .../model_providers/cohere/llm/command-r.yaml | 45 + .../model_providers/cohere/llm/command.yaml | 2 +- .../model_providers/cohere/llm/llm.py | 343 +++++-- .../cohere/rerank/_position.yaml | 4 + .../cohere/rerank/rerank-english-v3.0.yaml | 4 + .../rerank/rerank-multilingual-v3.0.yaml | 4 + .../model_providers/cohere/rerank/rerank.py | 31 +- .../cohere/text_embedding/text_embedding.py | 43 +- .../google/llm/gemini-1.5-pro-latest.yaml | 39 + .../google/llm/gemini-pro.yaml | 2 + .../model_providers/google/llm/llm.py | 200 ++-- .../model_providers/modelhub}/__init__.py | 0 .../modelhub/_assets/icon_l_en.svg | 4 + .../modelhub/_assets/icon_s_en.svg | 4 + .../model_providers/modelhub/_common.py | 44 + .../modelhub/llm/Baichuan2-Turbo.yaml | 27 + .../model_providers/modelhub/llm/__init__.py | 0 .../modelhub/llm/_position.yaml | 4 + .../modelhub/llm/glm-3-turbo.yaml | 21 + .../model_providers/modelhub/llm/glm-4.yaml | 21 + .../modelhub/llm/gpt-3.5-turbo.yaml | 33 + .../model_providers/modelhub/llm/gpt-4.yaml | 33 + .../model_providers/modelhub/llm/llm.py | 782 +++++++++++++++ .../model_providers/modelhub/modelhub.py | 11 + .../model_providers/modelhub/modelhub.yaml | 152 +++ .../modelhub/rerank/__init__.py | 0 .../modelhub/rerank/bge-reranker-base.yaml | 4 + .../modelhub/rerank/bge-reranker-v2-m3.yaml | 4 + .../model_providers/modelhub/rerank/rerank.py | 100 ++ .../modelhub/text_embedding/__init__.py | 0 .../modelhub/text_embedding/bge-m3.yaml | 9 + .../modelhub/text_embedding/m3e-large.yaml | 9 + .../text-embedding-3-large.yaml | 9 + .../text-embedding-3-small.yaml | 9 + .../modelhub/text_embedding/text_embedding.py | 244 +++++ .../model_providers/moonshot/llm/llm.py | 313 +++++- .../model_providers/moonshot/moonshot.yaml | 49 + .../model_providers/nvidia/llm/_position.yaml | 1 + .../nvidia/llm/codegemma-7b.yaml | 30 + .../model_providers/nvidia/llm/llm.py | 1 + .../model_providers/nvidia/nvidia.yaml | 2 +- .../nvidia/rerank/rerank-qa-mistral-4b.yaml | 2 +- .../model_providers/ollama/ollama.yaml | 4 +- .../model_providers/openai/llm/_position.yaml | 2 + .../openai/llm/gpt-4-turbo-2024-04-09.yaml | 57 ++ .../openai/llm/gpt-4-turbo.yaml | 57 ++ .../model_providers/openai/llm/llm.py | 28 + .../openai_api_compatible/llm/llm.py | 155 ++- .../openai_api_compatible.yaml | 47 +- .../model_providers/tongyi/llm/_client.py | 82 -- .../model_providers/tongyi/llm/llm.py | 301 ++++-- .../tongyi/llm/qwen-max-0403.yaml | 81 ++ .../tongyi/llm/qwen-max-1201.yaml | 6 +- .../tongyi/llm/qwen-max-longcontext.yaml | 6 +- .../model_providers/tongyi/llm/qwen-max.yaml | 6 +- .../tongyi/llm/qwen-plus-chat.yaml | 81 ++ .../model_providers/tongyi/llm/qwen-plus.yaml | 4 +- .../tongyi/llm/qwen-turbo-chat.yaml | 81 ++ .../tongyi/llm/qwen-turbo.yaml | 4 +- .../tongyi/llm/qwen-vl-max.yaml | 47 + .../tongyi/llm/qwen-vl-plus.yaml | 47 + .../tongyi/text_embedding/text_embedding.py | 22 +- .../model_providers/tongyi/tts/tts.py | 4 +- .../triton_inference_server.yaml | 12 +- .../wenxin/llm/ernie-3.5-4k-0205.yaml | 4 +- api/core/prompt/simple_prompt_transform.py | 4 +- .../rag/datasource/keyword/jieba/jieba.py | 145 +-- .../datasource/vdb/milvus/milvus_vector.py | 20 +- 
api/core/rag/datasource/vdb/relyt/__init__.py | 0 .../rag/datasource/vdb/relyt/relyt_vector.py | 169 ++++ api/core/rag/datasource/vdb/vector_factory.py | 26 + api/core/rag/extractor/blod/blod.py | 2 +- api/core/rag/extractor/csv_extractor.py | 1 + api/core/rag/extractor/excel_extractor.py | 47 +- api/core/rag/extractor/extract_processor.py | 9 +- .../unstructured_epub_extractor.py | 37 + api/core/rag/retrieval/agent/fake_llm.py | 59 -- api/core/rag/retrieval/agent/llm_chain.py | 46 - .../agent/multi_dataset_router_agent.py | 179 ---- .../agent/output_parser/structured_chat.py | 29 - .../structed_multi_dataset_router_agent.py | 259 ----- .../retrieval/agent_based_dataset_executor.py | 117 --- api/core/rag/retrieval/dataset_retrieval.py | 343 ++++++- .../rag/retrieval/output_parser/__init__.py | 0 .../retrieval/output_parser/react_output.py | 25 + .../output_parser/structured_chat.py | 25 + .../multi_dataset_function_call_router.py | 0 .../router}/multi_dataset_react_route.py | 57 +- api/core/tools/prompt/template.py | 4 +- api/core/tools/provider/_position.yaml | 2 + .../builtin/arxiv/tools/arxiv_search.py | 85 +- .../builtin/bing/tools/bing_web_search.py | 5 +- .../builtin/brave/tools/brave_search.py | 90 +- .../duckduckgo/tools/duckduckgo_search.py | 139 ++- .../builtin/google/tools/google_search.py | 37 +- .../provider/builtin/jina/_assets/icon.svg | 4 + api/core/tools/provider/builtin/jina/jina.py | 12 + .../tools/provider/builtin/jina/jina.yaml | 13 + .../builtin/jina/tools/jina_reader.py | 35 + .../builtin/jina/tools/jina_reader.yaml | 41 + .../builtin/pubmed/tools/pubmed_search.py | 177 +++- .../provider/builtin/searxng/_assets/icon.svg | 56 ++ .../tools/provider/builtin/searxng/searxng.py | 25 + .../provider/builtin/searxng/searxng.yaml | 24 + .../builtin/searxng/tools/searxng_search.py | 124 +++ .../builtin/searxng/tools/searxng_search.yaml | 89 ++ .../builtin/twilio/tools/send_message.py | 63 +- .../wikipedia/tools/wikipedia_search.py | 81 +- api/core/tools/tool/api_tool.py | 16 +- .../dataset_multi_retriever_tool.py | 19 +- .../dataset_retriever_base_tool.py | 34 + .../dataset_retriever_tool.py | 21 +- api/core/tools/tool/dataset_retriever_tool.py | 19 +- api/core/tools/tool/tool.py | 17 +- api/core/tools/tool_engine.py | 2 +- api/core/workflow/entities/variable_pool.py | 4 +- api/core/workflow/nodes/code/code_node.py | 3 + .../knowledge_retrieval_node.py | 237 +---- api/core/workflow/nodes/llm/llm_node.py | 35 +- .../question_classifier_node.py | 22 +- .../question_classifier/template_prompts.py | 6 +- api/core/workflow/nodes/tool/entities.py | 31 +- api/core/workflow/workflow_engine_manager.py | 40 +- api/events/event_handlers/__init__.py | 1 - .../event_handlers/create_document_index.py | 2 +- ...rsation_name_when_first_message_created.py | 32 - ...vider_last_used_at_when_messaeg_created.py | 4 +- api/extensions/ext_celery.py | 1 + api/extensions/ext_storage.py | 4 +- api/libs/json_in_md_parser.py | 2 +- api/models/model.py | 2 +- api/models/task.py | 8 +- api/models/workflow.py | 4 + api/requirements-dev.txt | 4 + api/requirements.txt | 32 +- api/services/account_service.py | 16 +- api/services/annotation_service.py | 2 +- api/services/app_service.py | 12 +- api/services/conversation_service.py | 12 +- api/services/dataset_service.py | 180 ++-- api/services/file_service.py | 13 +- api/services/vector_service.py | 2 +- api/services/web_conversation_service.py | 7 +- api/services/workflow_service.py | 18 +- api/tasks/add_document_to_index_task.py | 2 +- 
.../enable_annotation_reply_task.py | 2 +- .../batch_create_segment_to_index_task.py | 4 +- api/tasks/clean_dataset_task.py | 16 +- api/tasks/create_segment_to_index_task.py | 6 +- api/tasks/document_indexing_sync_task.py | 2 +- api/tasks/document_indexing_task.py | 4 +- api/tasks/document_indexing_update_task.py | 2 +- api/tasks/enable_segment_to_index_task.py | 2 +- api/tests/integration_tests/.env.example | 5 +- .../model_runtime/__mock/google.py | 11 +- api/tests/unit_tests/core/rag/__init__.py | 0 .../core/rag/datasource/__init__.py | 0 .../core/rag/datasource/vdb/__init__.py | 0 .../rag/datasource/vdb/milvus/__init__.py | 0 .../rag/datasource/vdb/milvus/test_milvus.py | 24 + dev/reformat | 8 + docker/docker-compose.middleware.yaml | 44 +- docker/docker-compose.yaml | 14 +- web/Dockerfile | 7 +- web/app/components/app-sidebar/app-info.tsx | 2 +- web/app/components/app/chat/thought/tool.tsx | 6 +- .../config-var/config-modal/index.tsx | 2 +- .../app/configuration/config/index.tsx | 5 +- .../dataset-config/card-item/item.tsx | 2 +- .../debug-with-multiple-model/chat-item.tsx | 2 +- .../debug-with-multiple-model/debug-item.tsx | 2 +- .../debug/debug-with-multiple-model/index.tsx | 5 +- .../debug/debug-with-single-model/index.tsx | 2 +- .../hooks/use-advanced-prompt-config.ts | 4 +- .../components/app/create-app-modal/index.tsx | 2 +- web/app/components/base/audio-btn/index.tsx | 111 ++- .../base/audio-btn/style.module.css | 2 +- .../base/auto-height-textarea/index.tsx | 12 +- .../base/chat/chat/answer/index.tsx | 4 +- web/app/components/base/chat/chat/hooks.ts | 2 +- web/app/components/base/chat/chat/index.tsx | 7 +- .../components/base/prompt-editor/index.tsx | 2 + .../plugins/on-blur-or-focus-block.tsx | 2 +- .../plugins/variable-value-block/index.tsx | 2 +- .../plugins/variable-value-block/utils.ts | 2 +- web/app/components/base/tag-input/index.tsx | 4 +- .../create/embedding-process/index.tsx | 45 +- .../documents/detail/embedding/index.tsx | 57 +- .../invited-modal/invitation-link.tsx | 2 +- .../model-provider-page/declarations.ts | 16 +- .../model-provider-page/hooks.ts | 4 +- .../model-parameter-modal/index.tsx | 6 +- .../share/text-generation/index.tsx | 7 +- .../share/text-generation/result/index.tsx | 23 +- .../workflow/custom-connection-line.tsx | 5 +- web/app/components/workflow/custom-edge.tsx | 5 +- .../workflow/hooks/use-nodes-interactions.ts | 138 ++- .../workflow/hooks/use-workflow-run.ts | 11 + .../components/workflow/hooks/use-workflow.ts | 39 +- web/app/components/workflow/index.tsx | 20 +- .../components/before-run-form/index.tsx | 14 +- .../nodes/_base/components/editor/base.tsx | 2 +- .../nodes/_base/components/variable/utils.ts | 80 +- .../variable/var-reference-vars.tsx | 24 +- .../components/workflow/nodes/end/node.tsx | 3 +- .../nodes/knowledge-retrieval/node.tsx | 11 +- .../llm/components/config-prompt-item.tsx | 107 +++ .../nodes/llm/components/config-prompt.tsx | 59 +- .../llm/components/resolution-picker.tsx | 4 +- .../components/workflow/nodes/llm/panel.tsx | 19 +- .../workflow/nodes/llm/use-config.ts | 55 +- .../workflow/nodes/start/use-config.ts | 8 +- .../workflow/nodes/tool/use-config.ts | 37 +- .../components/var-list/index.tsx | 2 + .../workflow/nodes/variable-assigner/node.tsx | 4 +- .../workflow/panel/chat-record/index.tsx | 8 +- .../panel/debug-and-preview/chat-wrapper.tsx | 2 +- .../workflow/panel/inputs-panel.tsx | 10 +- web/app/components/workflow/run/node.tsx | 6 +- .../components/workflow/run/result-panel.tsx | 8 +- 
web/app/components/workflow/store.ts | 8 + web/app/components/workflow/utils.ts | 86 +- web/i18n/de-DE/app-annotation.ts | 87 ++ web/i18n/de-DE/app-api.ts | 82 ++ web/i18n/de-DE/app-debug.ts | 409 ++++++++ web/i18n/de-DE/app-log.ts | 69 ++ web/i18n/de-DE/app-overview.ts | 139 +++ web/i18n/de-DE/app.ts | 54 ++ web/i18n/de-DE/billing.ts | 115 +++ web/i18n/de-DE/common.ts | 505 ++++++++++ web/i18n/de-DE/custom.ts | 30 + web/i18n/de-DE/dataset-creation.ts | 130 +++ web/i18n/de-DE/dataset-documents.ts | 349 +++++++ web/i18n/de-DE/dataset-hit-testing.ts | 28 + web/i18n/de-DE/dataset-settings.ts | 33 + web/i18n/de-DE/dataset.ts | 47 + web/i18n/de-DE/explore.ts | 41 + web/i18n/de-DE/layout.ts | 4 + web/i18n/de-DE/login.ts | 59 ++ web/i18n/de-DE/register.ts | 4 + web/i18n/de-DE/run-log.ts | 23 + web/i18n/de-DE/share-app.ts | 74 ++ web/i18n/de-DE/tools.ts | 115 +++ web/i18n/de-DE/workflow.ts | 333 +++++++ web/i18n/en-US/app-api.ts | 1 + web/i18n/en-US/common.ts | 2 +- web/i18n/en-US/workflow.ts | 3 + web/i18n/fr-FR/app-debug.ts | 39 +- web/i18n/fr-FR/app-log.ts | 64 +- web/i18n/fr-FR/app-overview.ts | 100 +- web/i18n/fr-FR/app.ts | 80 +- web/i18n/fr-FR/common.ts | 5 + web/i18n/fr-FR/explore.ts | 2 +- web/i18n/fr-FR/run-log.ts | 20 +- web/i18n/fr-FR/workflow.ts | 335 ++++++- web/i18n/ja-JP/app-debug.ts | 40 +- web/i18n/ja-JP/app-log.ts | 40 +- web/i18n/ja-JP/app-overview.ts | 102 +- web/i18n/ja-JP/app.ts | 88 +- web/i18n/ja-JP/common.ts | 5 + web/i18n/ja-JP/explore.ts | 4 +- web/i18n/ja-JP/run-log.ts | 20 +- web/i18n/ja-JP/workflow.ts | 333 ++++++- web/i18n/language.ts | 27 +- web/i18n/pt-BR/app-debug.ts | 20 +- web/i18n/pt-BR/app-log.ts | 46 +- web/i18n/pt-BR/app-overview.ts | 102 +- web/i18n/pt-BR/app.ts | 82 +- web/i18n/pt-BR/common.ts | 4 + web/i18n/pt-BR/explore.ts | 2 +- web/i18n/pt-BR/run-log.ts | 20 +- web/i18n/pt-BR/workflow.ts | 343 ++++++- web/i18n/uk-UA/app-debug.ts | 40 +- web/i18n/uk-UA/app-log.ts | 102 +- web/i18n/uk-UA/app-overview.ts | 158 ++-- web/i18n/uk-UA/app.ts | 80 +- web/i18n/uk-UA/common.ts | 4 + web/i18n/uk-UA/explore.ts | 2 +- web/i18n/uk-UA/run-log.ts | 22 +- web/i18n/uk-UA/workflow.ts | 345 ++++++- web/i18n/vi-VN/app-debug.ts | 39 +- web/i18n/vi-VN/app-log.ts | 68 +- web/i18n/vi-VN/app-overview.ts | 90 +- web/i18n/vi-VN/app.ts | 78 +- web/i18n/vi-VN/common.ts | 5 + web/i18n/vi-VN/explore.ts | 2 +- web/i18n/vi-VN/run-log.ts | 20 +- web/i18n/vi-VN/workflow.ts | 335 ++++++- web/i18n/zh-Hans/app-api.ts | 1 + web/i18n/zh-Hans/workflow.ts | 3 + web/package.json | 6 +- web/utils/model-config.ts | 3 + web/yarn.lock | 894 ++++++++++++------ 384 files changed, 15369 insertions(+), 4571 deletions(-) create mode 100644 api/core/agent/cot_chat_agent_runner.py create mode 100644 api/core/agent/cot_completion_agent_runner.py create mode 100644 api/core/agent/output_parser/cot_output_parser.py create mode 100644 api/core/llm_generator/output_parser/errors.py rename api/core/{rag/retrieval/agent => model_runtime/model_providers/bedrock/text_embedding}/__init__.py (100%) create mode 100644 api/core/model_runtime/model_providers/bedrock/text_embedding/_position.yaml create mode 100644 api/core/model_runtime/model_providers/bedrock/text_embedding/amazon.titan-embed-text-v1.yaml create mode 100644 api/core/model_runtime/model_providers/bedrock/text_embedding/cohere.embed-english-v3.yaml create mode 100644 api/core/model_runtime/model_providers/bedrock/text_embedding/cohere.embed-multilingual-v3.yaml create mode 100644 
api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py create mode 100644 api/core/model_runtime/model_providers/cohere/llm/command-r-plus.yaml create mode 100644 api/core/model_runtime/model_providers/cohere/llm/command-r.yaml create mode 100644 api/core/model_runtime/model_providers/cohere/rerank/_position.yaml create mode 100644 api/core/model_runtime/model_providers/cohere/rerank/rerank-english-v3.0.yaml create mode 100644 api/core/model_runtime/model_providers/cohere/rerank/rerank-multilingual-v3.0.yaml create mode 100644 api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml rename api/core/{rag/retrieval/agent/output_parser => model_runtime/model_providers/modelhub}/__init__.py (100%) create mode 100644 api/core/model_runtime/model_providers/modelhub/_assets/icon_l_en.svg create mode 100644 api/core/model_runtime/model_providers/modelhub/_assets/icon_s_en.svg create mode 100644 api/core/model_runtime/model_providers/modelhub/_common.py create mode 100644 api/core/model_runtime/model_providers/modelhub/llm/Baichuan2-Turbo.yaml create mode 100644 api/core/model_runtime/model_providers/modelhub/llm/__init__.py create mode 100644 api/core/model_runtime/model_providers/modelhub/llm/_position.yaml create mode 100644 api/core/model_runtime/model_providers/modelhub/llm/glm-3-turbo.yaml create mode 100644 api/core/model_runtime/model_providers/modelhub/llm/glm-4.yaml create mode 100644 api/core/model_runtime/model_providers/modelhub/llm/gpt-3.5-turbo.yaml create mode 100644 api/core/model_runtime/model_providers/modelhub/llm/gpt-4.yaml create mode 100644 api/core/model_runtime/model_providers/modelhub/llm/llm.py create mode 100644 api/core/model_runtime/model_providers/modelhub/modelhub.py create mode 100644 api/core/model_runtime/model_providers/modelhub/modelhub.yaml create mode 100644 api/core/model_runtime/model_providers/modelhub/rerank/__init__.py create mode 100644 api/core/model_runtime/model_providers/modelhub/rerank/bge-reranker-base.yaml create mode 100644 api/core/model_runtime/model_providers/modelhub/rerank/bge-reranker-v2-m3.yaml create mode 100644 api/core/model_runtime/model_providers/modelhub/rerank/rerank.py create mode 100644 api/core/model_runtime/model_providers/modelhub/text_embedding/__init__.py create mode 100644 api/core/model_runtime/model_providers/modelhub/text_embedding/bge-m3.yaml create mode 100644 api/core/model_runtime/model_providers/modelhub/text_embedding/m3e-large.yaml create mode 100644 api/core/model_runtime/model_providers/modelhub/text_embedding/text-embedding-3-large.yaml create mode 100644 api/core/model_runtime/model_providers/modelhub/text_embedding/text-embedding-3-small.yaml create mode 100644 api/core/model_runtime/model_providers/modelhub/text_embedding/text_embedding.py create mode 100644 api/core/model_runtime/model_providers/nvidia/llm/codegemma-7b.yaml create mode 100644 api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-2024-04-09.yaml create mode 100644 api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo.yaml delete mode 100644 api/core/model_runtime/model_providers/tongyi/llm/_client.py create mode 100644 api/core/model_runtime/model_providers/tongyi/llm/qwen-max-0403.yaml create mode 100644 api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-chat.yaml create mode 100644 api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo-chat.yaml create mode 100644 api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max.yaml create mode 100644 
api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus.yaml create mode 100644 api/core/rag/datasource/vdb/relyt/__init__.py create mode 100644 api/core/rag/datasource/vdb/relyt/relyt_vector.py create mode 100644 api/core/rag/extractor/unstructured/unstructured_epub_extractor.py delete mode 100644 api/core/rag/retrieval/agent/fake_llm.py delete mode 100644 api/core/rag/retrieval/agent/llm_chain.py delete mode 100644 api/core/rag/retrieval/agent/multi_dataset_router_agent.py delete mode 100644 api/core/rag/retrieval/agent/output_parser/structured_chat.py delete mode 100644 api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py delete mode 100644 api/core/rag/retrieval/agent_based_dataset_executor.py create mode 100644 api/core/rag/retrieval/output_parser/__init__.py create mode 100644 api/core/rag/retrieval/output_parser/react_output.py create mode 100644 api/core/rag/retrieval/output_parser/structured_chat.py rename api/core/{workflow/nodes/knowledge_retrieval => rag/retrieval/router}/multi_dataset_function_call_router.py (100%) rename api/core/{workflow/nodes/knowledge_retrieval => rag/retrieval/router}/multi_dataset_react_route.py (79%) create mode 100644 api/core/tools/provider/builtin/jina/_assets/icon.svg create mode 100644 api/core/tools/provider/builtin/jina/jina.py create mode 100644 api/core/tools/provider/builtin/jina/jina.yaml create mode 100644 api/core/tools/provider/builtin/jina/tools/jina_reader.py create mode 100644 api/core/tools/provider/builtin/jina/tools/jina_reader.yaml create mode 100644 api/core/tools/provider/builtin/searxng/_assets/icon.svg create mode 100644 api/core/tools/provider/builtin/searxng/searxng.py create mode 100644 api/core/tools/provider/builtin/searxng/searxng.yaml create mode 100644 api/core/tools/provider/builtin/searxng/tools/searxng_search.py create mode 100644 api/core/tools/provider/builtin/searxng/tools/searxng_search.yaml create mode 100644 api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py delete mode 100644 api/events/event_handlers/generate_conversation_name_when_first_message_created.py create mode 100644 api/requirements-dev.txt create mode 100644 api/tests/unit_tests/core/rag/__init__.py create mode 100644 api/tests/unit_tests/core/rag/datasource/__init__.py create mode 100644 api/tests/unit_tests/core/rag/datasource/vdb/__init__.py create mode 100644 api/tests/unit_tests/core/rag/datasource/vdb/milvus/__init__.py create mode 100644 api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py create mode 100644 web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx create mode 100644 web/i18n/de-DE/app-annotation.ts create mode 100644 web/i18n/de-DE/app-api.ts create mode 100644 web/i18n/de-DE/app-debug.ts create mode 100644 web/i18n/de-DE/app-log.ts create mode 100644 web/i18n/de-DE/app-overview.ts create mode 100644 web/i18n/de-DE/app.ts create mode 100644 web/i18n/de-DE/billing.ts create mode 100644 web/i18n/de-DE/common.ts create mode 100644 web/i18n/de-DE/custom.ts create mode 100644 web/i18n/de-DE/dataset-creation.ts create mode 100644 web/i18n/de-DE/dataset-documents.ts create mode 100644 web/i18n/de-DE/dataset-hit-testing.ts create mode 100644 web/i18n/de-DE/dataset-settings.ts create mode 100644 web/i18n/de-DE/dataset.ts create mode 100644 web/i18n/de-DE/explore.ts create mode 100644 web/i18n/de-DE/layout.ts create mode 100644 web/i18n/de-DE/login.ts create mode 100644 web/i18n/de-DE/register.ts create mode 100644 web/i18n/de-DE/run-log.ts create mode 100644 
web/i18n/de-DE/share-app.ts create mode 100644 web/i18n/de-DE/tools.ts create mode 100644 web/i18n/de-DE/workflow.ts diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index e1070e4ed3f37..ab585a5ae9361 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -1,8 +1,5 @@ FROM mcr.microsoft.com/devcontainers/python:3.10 -COPY . . - - # [Optional] Uncomment this section to install additional OS packages. # RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ # && apt-get -y install --no-install-recommends \ No newline at end of file diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index c028cc4886873..3a4d1fe2ea913 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -32,15 +32,22 @@ jobs: - name: Checkout code uses: actions/checkout@v4 + - name: Install APT packages + uses: awalsh128/cache-apt-pkgs-action@v1 + with: + packages: ffmpeg + - name: Set up Python uses: actions/setup-python@v5 with: python-version: '3.10' cache: 'pip' - cache-dependency-path: ./api/requirements.txt + cache-dependency-path: | + ./api/requirements.txt + ./api/requirements-dev.txt - name: Install dependencies - run: pip install -r ./api/requirements.txt + run: pip install -r ./api/requirements.txt -r ./api/requirements-dev.txt - name: Run ModelRuntime run: pytest api/tests/integration_tests/model_runtime/anthropic api/tests/integration_tests/model_runtime/azure_openai api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index a6cfd6f16369f..048f4cd942632 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -46,7 +46,7 @@ jobs: with: images: ${{ env[matrix.image_name_env] }} tags: | - type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' && startsWith(github.ref, 'refs/tags/') }} + type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }} type=ref,event=branch type=sha,enable=true,priority=100,prefix=,suffix=,format=long type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }} diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index c704ac1f7c5ae..bdbc22b489b78 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -24,11 +24,14 @@ jobs: python-version: '3.10' - name: Python dependencies - run: pip install ruff + run: pip install ruff dotenv-linter - name: Ruff check run: ruff check ./api + - name: Dotenv check + run: dotenv-linter ./api/.env.example ./web/.env.example + - name: Lint hints if: failure() run: echo "Please run 'dev/reformat' to fix the fixable linting errors." 
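For reviewers who want to reproduce the updated CI checks locally, here is a minimal sketch assembled from the workflow steps changed above. It assumes a Python 3.10 environment at the repository root and that ffmpeg is installed (the api-tests workflow now installs it via apt); the integration tests additionally expect provider credentials in the environment (see api/tests/integration_tests/.env.example).

```bash
# Install runtime and development dependencies (both files are now cached by CI)
pip install -r ./api/requirements.txt -r ./api/requirements-dev.txt

# Style checks mirroring .github/workflows/style.yml
pip install ruff dotenv-linter
ruff check ./api
dotenv-linter ./api/.env.example ./web/.env.example

# ModelRuntime integration tests mirroring .github/workflows/api-tests.yml
# (the workflow also covers chatglm, google, xinference and huggingface_hub)
pytest api/tests/integration_tests/model_runtime/anthropic \
  api/tests/integration_tests/model_runtime/azure_openai \
  api/tests/integration_tests/model_runtime/openai
```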
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 992126551cd34..e39f221382d26 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -36,7 +36,7 @@ In terms of licensing, please take a minute to read our short [License and Contr | Feature Type | Priority | | ------------------------------------------------------------ | --------------- | | High-Priority Features as being labeled by a team member | High Priority | - | Popular feature requests from our [community feedback board](https://feedback.dify.ai/) | Medium Priority | + | Popular feature requests from our [community feedback board](https://github.com/langgenius/dify/discussions/categories/feedbacks) | Medium Priority | | Non-core features and minor enhancements | Low Priority | | Valuable but not immediate | Future-Feature | diff --git a/CONTRIBUTING_CN.md b/CONTRIBUTING_CN.md index 08c5a0a4bd2d5..950ea19b255fc 100644 --- a/CONTRIBUTING_CN.md +++ b/CONTRIBUTING_CN.md @@ -34,7 +34,7 @@ | Feature Type | Priority | | ------------------------------------------------------------ | --------------- | | High-Priority Features as being labeled by a team member | High Priority | - | Popular feature requests from our [community feedback board](https://feedback.dify.ai/) | Medium Priority | + | Popular feature requests from our [community feedback board](https://github.com/langgenius/dify/discussions/categories/feedbacks) | Medium Priority | | Non-core features and minor enhancements | Low Priority | | Valuable but not immediate | Future-Feature | diff --git a/README.md b/README.md index eda3b759088f0..72c673326b90d 100644 --- a/README.md +++ b/README.md @@ -1,95 +1,176 @@ -[![](./images/GitHub_README_cover.png)](https://dify.ai) +![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +

- English | - 简体中文 | - 日本語 | - Español | - Klingon | - Français + Dify Cloud · + Self-hosting · + Documentation · + Enterprise inquiry

- Static Badge + Static Badge + + Static Badge - chat on Discord - follow on Twitter - Docker Pulls + Docker Pulls + + Commits last month + + Issues closed + + Discussion posts

- - 📌 Check out Dify Premium on AWS and deploy it to your own AWS VPC with one-click. - + Commits last month + Commits last month + Commits last month + Commits last month + Commits last month + Commits last month

-**Dify** is an open-source LLM app development platform. Dify's intuitive interface combines a RAG pipeline, AI workflow orchestration, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. - -![](./images/demo.png) - - - -## Using our Cloud Services - -You can try out [Dify.AI Cloud](https://dify.ai) now. It provides all the capabilities of the self-deployed version, and includes 200 free requests to OpenAI GPT-3.5. - -### Looking to purchase via AWS? -Check out [Dify Premium on AWS](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one-click. - -## Dify vs. LangChain vs. Assistants API - -| Feature | Dify.AI | Assistants API | LangChain | -|---------|---------|----------------|-----------| -| **Programming Approach** | API-oriented | API-oriented | Python Code-oriented | -| **Ecosystem Strategy** | Open Source | Close Source | Open Source | -| **RAG Engine** | Supported | Supported | Not Supported | -| **Prompt IDE** | Included | Included | None | -| **Supported LLMs** | Rich Variety | OpenAI-only | Rich Variety | -| **Local Deployment** | Supported | Not Supported | Not Applicable | - - - -## Features - -![](./images/models.png) - -**1. LLM Support**: Integration with OpenAI's GPT family of models, or the open-source Llama2 family models. In fact, Dify supports mainstream commercial models and open-source models (locally deployed or based on MaaS). - -**2. Prompt IDE**: Visual orchestration of applications and services based on LLMs with your team. - -**3. RAG Engine**: Includes various RAG capabilities based on full-text indexing or vector database embeddings, allowing direct upload of PDFs, TXTs, and other text formats. - -**4. AI Agent**: Based on Function Calling and ReAct, the Agent inference framework allows users to customize tools, what you see is what you get. Dify provides more than a dozen built-in tool calling capabilities, such as Google Search, DELL·E, Stable Diffusion, WolframAlpha, etc. - - -**5. Continuous Operations**: Monitor and analyze application logs and performance, continuously improving Prompts, datasets, or models using production data. - -## Before You Start - -**Star us on GitHub, and be instantly notified for new releases!** +# -![star-us](https://github.com/langgenius/dify/assets/100913391/95f37259-7370-4456-a9f0-0bc01ef8642f) - -- [Website](https://dify.ai) -- [Docs](https://docs.dify.ai) -- [Deployment Docs](https://docs.dify.ai/getting-started/install-self-hosted) -- [FAQ](https://docs.dify.ai/getting-started/faq) - - -## Install the Community Edition - -### System Requirements - -Before installing Dify, make sure your machine meets the following minimum system requirements: - -- CPU >= 2 Core -- RAM >= 4GB - -### Quick Start +

+ langgenius%2Fdify | Trendshift +

+Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. Here's a list of the core features: +

+ +**1. Workflow**: + Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond. + + + https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa + + + +**2. Comprehensive model support**: + Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama2, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers). + +![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) + + +**3. Prompt IDE**: + Intuitive interface for crafting prompts, comparing model performance, and adding additional features such as text-to-speech to a chat-based app. + +**4. RAG Pipeline**: + Extensive RAG capabilities that cover everything from document ingestion to retrieval, with out-of-box support for text extraction from PDFs, PPTs, and other common document formats. + +**5. Agent capabilities**: + You can define agents based on LLM Function Calling or ReAct, and add pre-built or custom tools for the agent. Dify provides 50+ built-in tools for AI agents, such as Google Search, DELL·E, Stable Diffusion and WolframAlpha. + +**6. LLMOps**: + Monitor and analyze application logs and performance over time. You could continuously improve prompts, datasets, and models based on production data and annotations. + +**7. Backend-as-a-Service**: + All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic. + + +## Feature comparison + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FeatureDify.AILangChainFlowiseOpenAI Assistants API
Programming ApproachAPI + App-orientedPython CodeApp-orientedAPI-oriented
Supported LLMsRich VarietyRich VarietyRich VarietyOpenAI-only
RAG Engine
Agent
Workflow
Observability
Enterprise Feature (SSO/Access control)
Local Deployment
+ +## Using Dify + +- **Cloud
** +We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan. + +- **Self-hosting Dify Community Edition
** +Quickly get Dify running in your environment with this [starter guide](#quick-start). +Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions. + +- **Dify for enterprise / organizations
** +We provide additional enterprise-centric features. [Schedule a meeting with us](https://cal.com/guchenhe/30min) or [send us an email](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) to discuss enterprise needs.
+ > For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one-click. It's an affordable AMI offering with the option to create apps with custom logo and branding. + + +## Staying ahead + +Star Dify on GitHub and be instantly notified of new releases. + +![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) + + + +## Quick start +> Before installing Dify, make sure your machine meets the following minimum system requirements: +> +>- CPU >= 2 Core +>- RAM >= 4GB + +
The easiest way to start the Dify server is to run our [docker-compose.yml](docker/docker-compose.yaml) file. Before running the installation command, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine: @@ -98,63 +179,65 @@ cd docker docker compose up -d ``` -After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization installation process. - -#### Deploy with Helm Chart +After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process. -[Helm Chart](https://helm.sh/) version, which allows Dify to be deployed on Kubernetes. +> If you'd like to contribute to Dify or do additional development, refer to our [guide to deploying from source code](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code) -- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) -- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) - -### Configuration +## Next steps -If you need to customize the configuration, please refer to the comments in our [docker-compose.yml](docker/docker-compose.yaml) file and manually set the environment configuration. After making the changes, please run `docker-compose up -d` again. You can see the full list of environment variables in our [docs](https://docs.dify.ai/getting-started/install-self-hosted/environments). +If you need to customize the configuration, please refer to the comments in our [docker-compose.yml](docker/docker-compose.yaml) file and manually set the environment configuration. After making the changes, please run `docker-compose up -d` again. You can see the full list of environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). +If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) which allow Dify to be deployed on Kubernetes. -## Star History +- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) +- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) -[![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) ## Contributing For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). - At the same time, please consider supporting Dify by sharing it on social media and at events and conferences. -### Projects made by community -- [Chatbot Chrome Extension by @charli117](https://github.com/langgenius/chatbot-chrome-extension) +> We are looking for contributors to help with translating Dify to languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c). -### Contributors +**Contributors** -### Translations - -We are looking for contributors to help with translating Dify to languages other than Mandarin or English. 
If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c). +## Community & contact -## Community & Support - -* [Github Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and checking out our feature roadmap. +* [Github Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and asking questions. * [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). -* [Email Support](mailto:hello@dify.ai?subject=[GitHub]Questions%20About%20Dify). Best for: questions you have about using Dify.AI. +* [Email](mailto:support@dify.ai?subject=[GitHub]Questions%20About%20Dify). Best for: questions you have about using Dify.AI. * [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community. * [Twitter](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community. -* [Business Contact](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry). Best for: business inquiries of licensing Dify.AI for commercial use. -### Direct Meetings +Or, schedule a meeting directly with a team member: + + + + + + + + + + + + + + +
Point of ContactPurpose
Git-Hub-README-Button-3xBusiness enquiries & product feedback
Git-Hub-README-Button-2xContributions, issues & feature requests
+ +## Star history -**Help us make Dify better. Reach out directly to us**. +[![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) -| Point of Contact | Purpose | -| :----------------------------------------------------------: | :----------------------------------------------------------: | -| Git-Hub-README-Button-3x | Product design feedback, user experience discussions, feature planning and roadmaps. | -| Git-Hub-README-Button-2x | Technical support, issues, or feature requests | -## Security Disclosure +## Security disclosure To protect your privacy, please avoid posting security issues on GitHub. Instead, send your questions to security@dify.ai and we will provide you with a more detailed answer. diff --git a/README_CN.md b/README_CN.md index d34652b80e855..08fec3a056fe1 100644 --- a/README_CN.md +++ b/README_CN.md @@ -1,78 +1,167 @@ -[![](./images/describe.png)](https://dify.ai) -

- English | - 简体中文 | - 日本語 | - Español | - Klingon | - Français -

+![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) + +

- Static Badge + Static Badge + + Static Badge - chat on Discord - follow on Twitter - Docker Pulls -

- -

- - Dify 发布 AI Agent 能力:基于不同的大型语言模型构建 GPTs 和 Assistants - + Docker Pulls + + Commits last month + + Issues closed + + Discussion posts

-Dify 是一个 LLM 应用开发平台,已经有超过 10 万个应用基于 Dify.AI 构建。它融合了 Backend as Service 和 LLMOps 的理念,涵盖了构建生成式 AI 原生应用所需的核心技术栈,包括一个内置 RAG 引擎。使用 Dify,你可以基于任何模型自部署类似 Assistants API 和 GPTs 的能力。 - -![](./images/demo.png) - -## 使用云端服务 - -使用 [Dify.AI Cloud](https://dify.ai) 提供开源版本的所有功能,并包含 200 次 GPT 试用额度。 - -## 为什么选择 Dify - -Dify 具有模型中立性,相较 LangChain 等硬编码开发库 Dify 是一个完整的、工程化的技术栈,而相较于 OpenAI 的 Assistants API 你可以完全将服务部署在本地。 - -| 功能 | Dify.AI | Assistants API | LangChain | -| --- | --- | --- | --- | -| 编程方式 | 面向 API | 面向 API | 面向 Python 代码 | -| 生态策略 | 开源 | 封闭且商用 | 开源 | -| RAG 引擎 | 支持 | 支持 | 不支持 | -| Prompt IDE | 包含 | 包含 | 没有 | -| 支持的 LLMs | 丰富 | 仅 GPT | 丰富 | -| 本地部署 | 支持 | 不支持 | 不适用 | - - -## 特点 - -![](./images/models.png) - -**1. LLM支持**:与 OpenAI 的 GPT 系列模型集成,或者与开源的 Llama2 系列模型集成。事实上,Dify支持主流的商业模型和开源模型(本地部署或基于 MaaS)。 - -**2. Prompt IDE**:和团队一起在 Dify 协作,通过可视化的 Prompt 和应用编排工具开发 AI 应用。 支持无缝切换多种大型语言模型。 - -**3. RAG引擎**:包括各种基于全文索引或向量数据库嵌入的 RAG 能力,允许直接上传 PDF、TXT 等各种文本格式。 - -**4. AI Agent**:基于 Function Calling 和 ReAct 的 Agent 推理框架,允许用户自定义工具,所见即所得。Dify 提供了十多种内置工具调用能力,如谷歌搜索、DELL·E、Stable Diffusion、WolframAlpha 等。 - -**5. 持续运营**:监控和分析应用日志和性能,使用生产数据持续改进 Prompt、数据集或模型。 - -## 在开始之前 - -**关注我们,您将立即收到 GitHub 上所有新发布版本的通知!** - -![star-us](https://github.com/langgenius/dify/assets/100913391/95f37259-7370-4456-a9f0-0bc01ef8642f) - -- [网站](https://dify.ai) -- [文档](https://docs.dify.ai) -- [部署文档](https://docs.dify.ai/getting-started/install-self-hosted) -- [常见问题](https://docs.dify.ai/getting-started/faq) +
+ 上个月的提交次数 + 上个月的提交次数 + 上个月的提交次数 + 上个月的提交次数 + 上个月的提交次数 + 上个月的提交次数 +
+ + +# + +
+ langgenius%2Fdify | 趋势转变 +
+ +Dify 是一个开源的LLM应用开发平台。其直观的界面结合了AI工作流程、RAG管道、代理功能、模型管理、可观察性功能等,让您可以快速从原型到生产。以下是其核心功能列表: +

+ +**1. 工作流**: + 在视觉画布上构建和测试功能强大的AI工作流程,利用以下所有功能以及更多功能。 + + + https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa + + + +**2. 全面的模型支持**: + 与数百种专有/开源LLMs以及数十种推理提供商和自托管解决方案无缝集成,涵盖GPT、Mistral、Llama2以及任何与OpenAI API兼容的模型。完整的支持模型提供商列表可在[此处](https://docs.dify.ai/getting-started/readme/model-providers)找到。 + +![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) + + +**3. Prompt IDE**: + 用于制作提示、比较模型性能以及向基于聊天的应用程序添加其他功能(如文本转语音)的直观界面。 + +**4. RAG Pipeline**: + 广泛的RAG功能,涵盖从文档摄入到检索的所有内容,支持从PDF、PPT和其他常见文档格式中提取文本的开箱即用的支持。 + +**5. Agent 智能体**: + 您可以基于LLM函数调用或ReAct定义代理,并为代理添加预构建或自定义工具。Dify为AI代理提供了50多种内置工具,如谷歌搜索、DELL·E、稳定扩散和WolframAlpha等。 + +**6. LLMOps**: + 随时间监视和分析应用程序日志和性能。您可以根据生产数据和注释持续改进提示、数据集和模型。 + +**7. 后端即服务**: + 所有Dify的功能都带有相应的API,因此您可以轻松地将Dify集成到自己的业务逻辑中。 + + +## 功能比较 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
功能Dify.AILangChainFlowiseOpenAI助理API
编程方法API + 应用程序导向Python代码应用程序导向API导向
支持的LLMs丰富多样丰富多样丰富多样仅限OpenAI
RAG引擎
代理
工作流程
可观察性
企业功能(SSO/访问控制)
本地部署
+ +## 使用 Dify + +- **云
** +我们提供[ Dify 云服务](https://dify.ai),任何人都可以零设置尝试。它提供了自部署版本的所有功能,并在沙盒计划中包含 200 次免费的 GPT-4 调用。 + +- **自托管 Dify 社区版
** +使用这个[入门指南](#quick-start)快速在您的环境中运行 Dify。 +使用我们的[文档](https://docs.dify.ai)进行进一步的参考和更深入的说明。 + +- **面向企业/组织的 Dify
** +我们提供额外的面向企业的功能。[与我们安排会议](https://cal.com/guchenhe/30min)或[给我们发送电子邮件](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)讨论企业需求。
+ > 对于使用 AWS 的初创公司和中小型企业,请查看 [AWS Marketplace 上的 Dify 高级版](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6),并使用一键部署到您自己的 AWS VPC。它是一个价格实惠的 AMI 产品,提供了使用自定义徽标和品牌创建应用程序的选项。 + +## 保持领先 + +在 GitHub 上给 Dify Star,并立即收到新版本的通知。 + +![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) ## 安装社区版 @@ -110,6 +199,19 @@ docker compose up -d [![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) +## Contributing + +对于那些想要贡献代码的人,请参阅我们的[贡献指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。 +同时,请考虑通过社交媒体、活动和会议来支持Dify的分享。 + +> 我们正在寻找贡献者来帮助将Dify翻译成除了中文和英文之外的其他语言。如果您有兴趣帮助,请参阅我们的[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md)获取更多信息,并在我们的[Discord社区服务器](https://discord.gg/8Tpq4AcN9c)的`global-users`频道中留言。 + +**Contributors** + + + + + ## 社区与支持 我们欢迎您为 Dify 做出贡献,以帮助改善 Dify。包括:提交代码、问题、新想法,或分享您基于 Dify 创建的有趣且有用的 AI 应用程序。同时,我们也欢迎您在不同的活动、会议和社交媒体上分享 Dify。 diff --git a/README_ES.md b/README_ES.md index 15caeabc2cd35..e3e8d4a45b1d8 100644 --- a/README_ES.md +++ b/README_ES.md @@ -1,119 +1,245 @@ -[![](./images/describe.png)](https://dify.ai) +![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +

- English | - 简体中文 | - 日本語 | - Español | - Klingon | - Français + Dify Cloud · + Auto-alojamiento · + Documentación · + Programar demostración

- Static Badge + Insignia Estática + + Insignia Estática - chat on Discord + chat en Discord - follow on Twitter + seguir en Twitter - Docker Pulls + Descargas de Docker + + Actividad de Commits el último mes + + Issues cerrados + + Publicaciones de discusión

- - Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs - + Actividad de Commits el último mes + Actividad de Commits el último mes + Actividad de Commits el último mes + Actividad de Commits el último mes + Actividad de Commits el último mes + Actividad de Commits el último mes

-**Dify** es una plataforma de desarrollo de aplicaciones para modelos de lenguaje de gran tamaño (LLM) que ya ha visto la creación de más de **100,000** aplicaciones basadas en Dify.AI. Integra los conceptos de Backend como Servicio y LLMOps, cubriendo el conjunto de tecnologías esenciales requerido para construir aplicaciones nativas de inteligencia artificial generativa, incluyendo un motor RAG incorporado. Con Dify, **puedes auto-desplegar capacidades similares a las de Assistants API y GPTs basadas en cualquier LLM.** - -![](./images/demo.png) - -## Utilizar Servicios en la Nube - -Usar [Dify.AI Cloud](https://dify.ai) proporciona todas las capacidades de la versión de código abierto, e incluye un complemento de 200 créditos de prueba para GPT. - -## Por qué Dify - -Dify se caracteriza por su neutralidad de modelo y es un conjunto tecnológico completo e ingenierizado, en comparación con las bibliotecas de desarrollo codificadas como LangChain. A diferencia de la API de Assistants de OpenAI, Dify permite el despliegue local completo de los servicios. +# -| Característica | Dify.AI | API de Assistants | LangChain | -|----------------|---------|------------------|-----------| -| **Enfoque de Programación** | Orientado a API | Orientado a API | Orientado a Código en Python | -| **Estrategia del Ecosistema** | Código Abierto | Cerrado y Comercial | Código Abierto | -| **Motor RAG** | Soportado | Soportado | No Soportado | -| **IDE de Prompts** | Incluido | Incluido | Ninguno | -| **LLMs Soportados** | Gran Variedad | Solo GPT | Gran Variedad | -| **Despliegue Local** | Soportado | No Soportado | No Aplicable | - -## Características - -![](./images/models.png) - -**1. Soporte LLM**: Integración con la familia de modelos GPT de OpenAI, o los modelos de la familia Llama2 de código abierto. De hecho, Dify soporta modelos comerciales convencionales y modelos de código abierto (desplegados localmente o basados en MaaS). - -**2. IDE de Prompts**: Orquestación visual de aplicaciones y servicios basados en LLMs con tu equipo. +

+ langgenius%2Fdify | Trendshift +

+Dify es una plataforma de desarrollo de aplicaciones de LLM de código abierto. Su interfaz intuitiva combina flujo de trabajo de IA, pipeline RAG, capacidades de agente, gestión de modelos, características de observabilidad y más, lo que le permite pasar rápidamente de un prototipo a producción. Aquí hay una lista de las características principales: +

+ +**1. Flujo de trabajo**: + Construye y prueba potentes flujos de trabajo de IA en un lienzo visual, aprovechando todas las siguientes características y más. + + + https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa + + + +**2. Soporte de modelos completo**: + Integración perfecta con cientos de LLMs propietarios / de código abierto de docenas de proveedores de inferencia y soluciones auto-alojadas, que cubren GPT, Mistral, Llama2 y cualquier modelo compatible con la API de OpenAI. Se puede encontrar una lista completa de proveedores de modelos admitidos [aquí](https://docs.dify.ai/getting-started/readme/model-providers). + +![proveedores-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) + + +**3. IDE de prompt**: + Interfaz intuitiva para crear prompts, comparar el rendimiento del modelo y agregar características adicionales como texto a voz a una aplicación basada en chat. + +**4. Pipeline RAG**: + Amplias capacidades de RAG que cubren todo, desde la ingestión de documentos hasta la recuperación, con soporte listo para usar para la extracción de texto de PDF, PPT y otros formatos de documento comunes. + +**5. Capacidades de agente**: + Puedes definir agent + +es basados en LLM Function Calling o ReAct, y agregar herramientas preconstruidas o personalizadas para el agente. Dify proporciona más de 50 herramientas integradas para agentes de IA, como Búsqueda de Google, DELL·E, Difusión Estable y WolframAlpha. + +**6. LLMOps**: + Supervisa y analiza registros de aplicaciones y rendimiento a lo largo del tiempo. Podrías mejorar continuamente prompts, conjuntos de datos y modelos basados en datos de producción y anotaciones. + +**7. Backend como servicio**: + Todas las ofertas de Dify vienen con APIs correspondientes, por lo que podrías integrar Dify sin esfuerzo en tu propia lógica empresarial. + + +## Comparación de características + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
CaracterísticaDify.AILangChainFlowiseAPI de Asistentes de OpenAI
Enfoque de programaciónAPI + orientado a la aplicaciónCódigo PythonOrientado a la aplicaciónOrientado a la API
LLMs admitidosGran variedadGran variedadGran variedadSolo OpenAI
Motor RAG
Agente
Flujo de trabajo
Observabilidad
Característica empresarial (SSO/Control de acceso)
Implementación local
+ +## Usando Dify + +- **Nube
** +Hospedamos un servicio [Dify Cloud](https://dify.ai) para que cualquiera lo pruebe sin configuración. Proporciona todas las capacidades de la versión autoimplementada e incluye 200 llamadas gratuitas a GPT-4 en el plan sandbox. + +- **Auto-alojamiento de Dify Community Edition
** +Pon rápidamente Dify en funcionamiento en tu entorno con esta [guía de inicio rápido](#quick-start). +Usa nuestra [documentación](https://docs.dify.ai) para más referencias e instrucciones más detalladas. + +- **Dify para Empresas / Organizaciones
** +Proporcionamos características adicionales centradas en la empresa. [Programa una reunión con nosotros](https://cal.com/guchenhe/30min) o [envíanos un correo electrónico](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) para discutir las necesidades empresariales.
+ > Para startups y pequeñas empresas que utilizan AWS, echa un vistazo a [Dify Premium en AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) e impleméntalo en tu propio VPC de AWS con un clic. Es una AMI asequible que ofrece la opción de crear aplicaciones con logotipo y marca personalizados. + + +## Manteniéndote al tanto + +Dale estrella a Dify en GitHub y serás notificado instantáneamente de las nuevas versiones. + +![danos estrella](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) + + + +## Inicio Rápido +> Antes de instalar Dify, asegúrate de que tu máquina cumpla con los siguientes requisitos mínimos del sistema: +> +>- CPU >= 2 núcleos +>- RAM >= 4GB + +
+ +La forma más fácil de iniciar el servidor de Dify es ejecutar nuestro archivo [docker-compose.yml](docker/docker-compose.yaml). Antes de ejecutar el comando de instalación, asegúrate de que [Docker](https://docs.docker.com/get-docker/) y [Docker Compose](https://docs.docker.com/compose/install/) estén instalados en tu máquina: -**3. Motor RAG**: Incluye varias capacidades RAG basadas en indexación de texto completo o incrustaciones de base de datos vectoriales, permitiendo la carga directa de PDFs, TXTs y otros formatos de texto. +```bash +cd docker +docker compose up -d +``` -**4. Agente de IA**: Basado en la llamada de funciones y ReAct, el marco de inferencia del Agente permite a los usuarios personalizar las herramientas, lo que ves es lo que obtienes. Dify proporciona más de una docena de capacidades de llamada de herramientas incorporadas, como Búsqueda de Google, DELL·E, Difusión Estable, WolframAlpha, etc. +Después de ejecutarlo, puedes acceder al panel de control de Dify en tu navegador en [http://localhost/install](http://localhost/install) y comenzar el proceso de inicialización. -**5. Operaciones Continuas**: Monitorear y analizar registros de aplicaciones y rendimiento, mejorando continuamente Prompts, conjuntos de datos o modelos usando datos de producción. +> Si deseas contribuir a Dify o realizar desarrollo adicional, consulta nuestra [guía para implementar desde el código fuente](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code) -## Antes de Empezar +## Próximos pasos -**¡Danos una estrella, y recibirás notificaciones instantáneas de todos los nuevos lanzamientos en GitHub!** +Si necesitas personalizar la configuración, consulta los comentarios en nuestro archivo [docker-compose.yml](docker/docker-compose.yaml) y configura manualmente la configuración del entorno -![star-us](https://github.com/langgenius/dify/assets/100913391/95f37259-7370-4456-a9f0-0bc01ef8642f) +. Después de realizar los cambios, ejecuta `docker-compose up -d` nuevamente. Puedes ver la lista completa de variables de entorno [aquí](https://docs.dify.ai/getting-started/install-self-hosted/environments). -- [Sitio web](https://dify.ai) -- [Documentación](https://docs.dify.ai) -- [Documentación de Implementación](https://docs.dify.ai/getting-started/install-self-hosted) -- [Preguntas Frecuentes](https://docs.dify.ai/getting-started/faq) +Si deseas configurar una instalación altamente disponible, hay [Gráficos Helm](https://helm.sh/) contribuidos por la comunidad que permiten implementar Dify en Kubernetes. -## Instalar la Edición Comunitaria +- [Gráfico Helm por @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) +- [Gráfico Helm por @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) -### Requisitos del Sistema -Antes de instalar Dify, asegúrate de que tu máquina cumpla con los siguientes requisitos mínimos del sistema: +## Contribuir -- CPU >= 2 núcleos -- RAM >= 4GB +Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +Al mismo tiempo, considera apoyar a Dify compartiéndolo en redes sociales y en eventos y conferencias. -### Inicio Rápido -La forma más sencilla de iniciar el servidor de Dify es ejecutar nuestro archivo [docker-compose.yml](docker/docker-compose.yaml). 
Antes de ejecutar el comando de instalación, asegúrate de que [Docker](https://docs.docker.com/get-docker/) y [Docker Compose](https://docs.docker.com/compose/install/) estén instalados en tu máquina: +> Estamos buscando colaboradores para ayudar con la traducción de Dify a idiomas que no sean el mandarín o el inglés. Si estás interesado en ayudar, consulta el [README de i18n](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) para obtener más información y déjanos un comentario en el canal `global-users` de nuestro [Servidor de Comunidad en Discord](https://discord.gg/8Tpq4AcN9c). -```bash -cd docker -docker compose up -d -``` +**Contribuidores** -Después de ejecutarlo, puedes acceder al panel de control de Dify en tu navegador en [http://localhost/install](http://localhost/install) y comenzar el proceso de instalación de inicialización. + + + -### Gráfico Helm +## Comunidad y Contacto -Un gran agradecimiento a @BorisPolonsky por proporcionarnos una versión del [Gráfico Helm](https://helm.sh/), que permite implementar Dify en Kubernetes. Puedes visitar https://github.com/BorisPolonsky/dify-helm para obtener información sobre la implementación. +* [Discusión en GitHub](https://github.com/langgenius/dify/discussions). Lo mejor para: compartir comentarios y hacer preguntas. +* [Reporte de problemas en GitHub](https://github.com/langgenius/dify/issues). Lo mejor para: errores que encuentres usando Dify.AI y propuestas de características. Consulta nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +* [Correo electrónico](mailto:support@dify.ai?subject=[GitHub]Questions%20About%20Dify). Lo mejor para: preguntas que tengas sobre el uso de Dify.AI. +* [Discord](https://discord.gg/FngNHpbcY7). Lo mejor para: compartir tus aplicaciones y pasar el rato con la comunidad. +* [Twitter](https://twitter.com/dify_ai). Lo mejor para: compartir tus aplicaciones y pasar el rato con la comunidad. -### Configuración +O, programa una reunión directamente con un miembro del equipo: -Si necesitas personalizar la configuración, consulta los comentarios en nuestro archivo [docker-compose.yml](docker/docker-compose.yaml) y configura manualmente la configuración del entorno. Después de realizar los cambios, ejecuta nuevamente `docker-compose up -d`. Puedes ver la lista completa de variables de entorno en nuestra [documentación](https://docs.dify.ai/getting-started/install-self-hosted/environments). + + + + + + + + + + + + + +
Punto de ContactoPropósito
Git-Hub-README-Button-3xConsultas comerciales y retroalimentación del producto
Git-Hub-README-Button-2xContribuciones, problemas y solicitudes de características
## Historial de Estrellas [![Gráfico de Historial de Estrellas](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) -## Comunidad y Soporte - -Te damos la bienvenida a contribuir a Dify para ayudar a hacer que Dify sea mejor de diversas maneras, enviando código, informando problemas, proponiendo nuevas ideas o compartiendo las aplicaciones de inteligencia artificial interesantes y útiles que hayas creado basadas en Dify. Al mismo tiempo, también te invitamos a compartir Dify en diferentes eventos, conferencias y redes sociales. - -- [Problemas en GitHub](https://github.com/langgenius/dify/issues). Lo mejor para: errores y problemas que encuentres al usar Dify.AI, consulta la [Guía de Contribución](CONTRIBUTING.md). -- [Soporte por Correo Electrónico](mailto:hello@dify.ai?subject=[GitHub]Preguntas%20sobre%20Dify). Lo mejor para: preguntas que tengas sobre el uso de Dify.AI. -- [Discord](https://discord.gg/FngNHpbcY7). Lo mejor para: compartir tus aplicaciones y socializar con la comunidad. -- [Twitter](https://twitter.com/dify_ai). Lo mejor para: compartir tus aplicaciones y socializar con la comunidad. -- [Licencia Comercial](mailto:business@dify.ai?subject=[GitHub]Consulta%20de%20Licencia%20Comercial). Lo mejor para: consultas comerciales sobre la licencia de Dify.AI para uso comercial. ## Divulgación de Seguridad @@ -121,4 +247,4 @@ Para proteger tu privacidad, evita publicar problemas de seguridad en GitHub. En ## Licencia -Este repositorio está disponible bajo la [Licencia de Código Abierto Dify](LICENSE), que es esencialmente Apache 2.0 con algunas restricciones adicionales. +Este repositorio está disponible bajo la [Licencia de Código Abierto de Dify](LICENSE), que es esencialmente Apache 2.0 con algunas restricciones adicionales. \ No newline at end of file diff --git a/README_FR.md b/README_FR.md index cada725d5aa7c..461408605a1bc 100644 --- a/README_FR.md +++ b/README_FR.md @@ -1,127 +1,250 @@ -[![](./images/describe.png)](https://dify.ai) +![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +

- English | - 简体中文 | - 日本語 | - Español | - Klingon | - Français + Dify Cloud · + Auto-hébergement · + Documentation · + Planifier une démo

- Static Badge + Badge statique + + Badge statique - chat on Discord + chat sur Discord - follow on Twitter + suivre sur Twitter - Docker Pulls + Tirages Docker + + Commits le mois dernier + + Problèmes fermés + + Messages de discussion

- - Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs - + Commits le mois dernier + Commits le mois dernier + Commits le mois dernier + Commits le mois dernier + Commits le mois dernier + Commits le mois dernier

+# -**Dify** est une plateforme de développement d'applications LLM qui a déjà vu plus de **100,000** applications construites sur Dify.AI. Elle intègre les concepts de Backend as a Service et LLMOps, couvrant la pile technologique de base requise pour construire des applications natives d'IA générative, y compris un moteur RAG intégré. Avec Dify, **vous pouvez auto-déployer des capacités similaires aux API Assistants et GPT basées sur n'importe quels LLM.** - -![](./images/demo.png) - -## Utiliser les services cloud - -L'utilisation de [Dify.AI Cloud](https://dify.ai) fournit toutes les capacités de la version open source, et comprend un essai gratuit de 200 crédits GPT. - -## Pourquoi Dify - -Dify présente une neutralité de modèle et est une pile technologique complète et conçue par rapport à des bibliothèques de développement codées en dur comme LangChain. Contrairement à l'API Assistants d'OpenAI, Dify permet un déploiement local complet des services. - -| Fonctionnalité | Dify.AI | API Assistants | LangChain | -|---------------|----------|-----------------|------------| -| **Approche de programmation** | Orientée API | Orientée API | Orientée code Python | -| **Stratégie écosystème** | Open source | Fermé et commercial | Open source | -| **Moteur RAG** | Pris en charge | Pris en charge | Non pris en charge | -| **IDE d'invite** | Inclus | Inclus | Aucun | -| **LLM pris en charge** | Grande variété | Seulement GPT | Grande variété | -| **Déploiement local** | Pris en charge | Non pris en charge | Non applicable | - - ## Fonctionnalités - -![](./images/models.png) - -**1\. Support LLM**: Intégration avec la famille de modèles GPT d'OpenAI, ou les modèles de la famille open source Llama2. En fait, Dify prend en charge les modèles commerciaux grand public et les modèles open source (déployés localement ou basés sur MaaS). - -**2\. IDE d'invite**: Orchestration visuelle d'applications et de services basés sur LLMs avec votre équipe. - -**3\. Moteur RAG**: Comprend diverses capacités RAG basées sur l'indexation de texte intégral ou les embeddings de base de données vectorielles, permettant le chargement direct de PDF, TXT et autres formats de texte. - -**4\. AI Agent**: Basé sur l'appel de fonction et ReAct, le framework d'inférence de l'Agent permet aux utilisateurs de personnaliser les outils, ce que vous voyez est ce que vous obtenez. Dify propose plus d'une douzaine de capacités d'appel d'outils intégrées, telles que la recherche Google, DELL·E, Diffusion Stable, WolframAlpha, etc. - -**5\. Opérations continues**: Surveillez et analysez les journaux et les performances des applications, améliorez en continu les invites, les datasets ou les modèles à l'aide de données de production. +

+ langgenius%2Fdify | Trendshift +

+Dify est une plateforme de développement d'applications LLM open source. Son interface intuitive combine un flux de travail d'IA, un pipeline RAG, des capacités d'agent, une gestion de modèles, des fonctionnalités d'observabilité, et plus encore, vous permettant de passer rapidement du prototype à la production. Voici une liste des fonctionnalités principales: +

+ +**1. Flux de travail**: + Construisez et testez des flux de travail d'IA puissants sur un canevas visuel, en utilisant toutes les fonctionnalités suivantes et plus encore. + + + https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa + + + +**2. Prise en charge complète des modèles**: + Intégration transparente avec des centaines de LLM propriétaires / open source provenant de dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama2, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers). + +![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) + + +**3. IDE de prompt**: + Interface intuitive pour créer des prompts, comparer les performances des modèles et ajouter des fonctionnalités supplémentaires telles que la synthèse vocale à une application basée sur des chats. + +**4. Pipeline RAG**: + Des capacités RAG étendues qui couvrent tout, de l'ingestion de documents à la récupération, avec un support prêt à l'emploi pour l'extraction de texte à partir de PDF, PPT et autres formats de document courants. + +**5. Capac + +ités d'agent**: + Vous pouvez définir des agents basés sur l'appel de fonction LLM ou ReAct, et ajouter des outils pré-construits ou personnalisés pour l'agent. Dify fournit plus de 50 outils intégrés pour les agents d'IA, tels que la recherche Google, DELL·E, Stable Diffusion et WolframAlpha. + +**6. LLMOps**: + Surveillez et analysez les journaux d'application et les performances au fil du temps. Vous pouvez continuellement améliorer les prompts, les ensembles de données et les modèles en fonction des données de production et des annotations. + +**7. Backend-as-a-Service**: + Toutes les offres de Dify sont accompagnées d'API correspondantes, vous permettant d'intégrer facilement Dify dans votre propre logique métier. + + +## Comparaison des fonctionnalités + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FonctionnalitéDify.AILangChainFlowiseOpenAI Assistants API
Approche de programmationAPI + ApplicationCode PythonApplicationAPI
LLMs pris en chargeGrande variétéGrande variétéGrande variétéUniquement OpenAI
Moteur RAG
Agent
Flux de travail
Observabilité
Fonctionnalité d'entreprise (SSO/Contrôle d'accès)
Déploiement local
+ +## Utiliser Dify + +- **Cloud
** +Nous hébergeons un service [Dify Cloud](https://dify.ai) pour que tout le monde puisse l'essayer sans aucune configuration. Il fournit toutes les capacités de la version auto-hébergée et comprend 200 appels GPT-4 gratuits dans le plan bac à sable. + +- **Auto-hébergement Dify Community Edition
** +Lancez rapidement Dify dans votre environnement avec ce [guide de démarrage](#quick-start). +Utilisez notre [documentation](https://docs.dify.ai) pour plus de références et des instructions plus détaillées. + +- **Dify pour les entreprises / organisations
** +Nous proposons des fonctionnalités supplémentaires adaptées aux entreprises. [Planifiez une réunion avec nous](https://cal.com/guchenhe/30min) ou [envoyez-nous un e-mail](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) pour discuter des besoins de l'entreprise.
+ > Pour les startups et les petites entreprises utilisant AWS, consultez [Dify Premium sur AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) et déployez-le dans votre propre VPC AWS en un clic. C'est une offre AMI abordable avec la possibilité de créer des applications avec un logo et une marque personnalisés. + + +## Rester en avance + +Mettez une étoile à Dify sur GitHub et soyez instantanément informé des nouvelles versions. + +![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) + + + +## Démarrage rapide +> Avant d'installer Dify, assurez-vous que votre machine répond aux exigences système minimales suivantes: +> +>- CPU >= 2 cœurs +>- RAM >= 4 Go + +
+ +La manière la plus simple de démarrer le serveur Dify est d'exécuter notre fichier [docker-compose.yml](docker/docker-compose.yaml). Avant d'exécuter la commande d'installation, assurez-vous que [Docker](https://docs.docker.com/get-docker/) et [Docker Compose](https://docs.docker.com/compose/install/) sont installés sur votre machine: -## Avant de commencer +```bash +cd docker +docker compose up -d +``` -**Étoilez-nous, et vous recevrez des notifications instantanées pour toutes les nouvelles sorties sur GitHub !** -![star-us](https://github.com/langgenius/dify/assets/100913391/95f37259-7370-4456-a9f0-0bc01ef8642f) +Après l'exécution, vous pouvez accéder au tableau de bord Dify dans votre navigateur à [http://localhost/install](http://localhost/install) et commencer le processus d'initialisation. -- [Site web](https://dify.ai) -- [Documentation](https://docs.dify.ai) -- [Documentation de déploiement](https://docs.dify.ai/getting-started/install-self-hosted) -- [FAQ](https://docs.dify.ai/getting-started/faq) +> Si vous souhaitez contribuer à Dify ou effectuer un développement supplémentaire, consultez notre [guide de déploiement à partir du code source](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code) +## Prochaines étapes -## Installer la version Communauté +Si vous devez personnaliser la configuration, veuillez -### Configuration système + vous référer aux commentaires dans notre fichier [docker-compose.yml](docker/docker-compose.yaml) et définir manuellement la configuration de l'environnement. Après avoir apporté les modifications, veuillez exécuter à nouveau `docker-compose up -d`. Vous pouvez voir la liste complète des variables d'environnement [ici](https://docs.dify.ai/getting-started/install-self-hosted/environments). -Avant d'installer Dify, assurez-vous que votre machine répond aux exigences minimales suivantes: +Si vous souhaitez configurer une installation hautement disponible, il existe des [Helm Charts](https://helm.sh/) contribués par la communauté qui permettent de déployer Dify sur Kubernetes. -- CPU >= 2 cœurs -- RAM >= 4 Go +- [Helm Chart par @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) +- [Helm Chart par @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) -### Démarrage rapide -La façon la plus simple de démarrer le serveur Dify est d'exécuter notre fichier [docker-compose.yml](docker/docker-compose.yaml). Avant d'exécuter la commande d'installation, assurez-vous que [Docker](https://docs.docker.com/get-docker/) et [Docker Compose](https://docs.docker.com/compose/install/) sont installés sur votre machine: +## Contribuer -```bash -cd docker -docker compose up -d -``` +Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +Dans le même temps, veuillez envisager de soutenir Dify en le partageant sur les réseaux sociaux et lors d'événements et de conférences. -Après l'exécution, vous pouvez accéder au tableau de bord Dify dans votre navigateur à l'adresse [http://localhost/install](http://localhost/install) et démarrer le processus d'installation initiale. -### Chart Helm +> Nous recherchons des contributeurs pour aider à traduire Dify dans des langues autres que le mandarin ou l'anglais. 
Si vous êtes intéressé à aider, veuillez consulter le [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) pour plus d'informations, et laissez-nous un commentaire dans le canal `global-users` de notre [Serveur communautaire Discord](https://discord.gg/8Tpq4AcN9c). -Un grand merci à @BorisPolonsky pour nous avoir fourni une version [Helm Chart](https://helm.sh/) qui permet le déploiement de Dify sur Kubernetes. -Vous pouvez accéder à https://github.com/BorisPolonsky/dify-helm pour des informations de déploiement. +**Contributeurs** -### Configuration + + + -Si vous avez besoin de personnaliser la configuration, veuillez vous référer aux commentaires de notre fichier [docker-compose.yml](docker/docker-compose.yaml) et définir manuellement la configuration de l'environnement. Après avoir apporté les modifications, veuillez exécuter à nouveau `docker-compose up -d`. Vous trouverez la liste complète des variables d'environnement dans notre [documentation](https://docs.dify.ai/getting-started/install-self-hosted/environments). +## Communauté & Contact -## Historique d'étoiles +* [Discussion GitHub](https://github.com/langgenius/dify/discussions). Meilleur pour: partager des commentaires et poser des questions. +* [Problèmes GitHub](https://github.com/langgenius/dify/issues). Meilleur pour: les bogues que vous rencontrez en utilisant Dify.AI et les propositions de fonctionnalités. Consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +* [E-mail](mailto:support@dify.ai?subject=[GitHub]Questions%20About%20Dify). Meilleur pour: les questions que vous avez sur l'utilisation de Dify.AI. +* [Discord](https://discord.gg/FngNHpbcY7). Meilleur pour: partager vos applications et passer du temps avec la communauté. +* [Twitter](https://twitter.com/dify_ai). Meilleur pour: partager vos applications et passer du temps avec la communauté. -[![Diagramme de l'historique des étoiles](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) +Ou, planifiez directement une réunion avec un membre de l'équipe: + + + + + + + + + + + + + +
Point de contactObjectif
Git-Hub-README-Button-3xDemandes commerciales & retours produit
Git-Hub-README-Button-2xContributions, problèmes & demandes de fonctionnalités
-## Communauté & Support +## Historique des étoiles -Nous vous invitons à contribuer à Dify pour aider à améliorer Dify de diverses manières, en soumettant du code, des problèmes, de nouvelles idées ou en partageant les applications d'IA intéressantes et utiles que vous avez créées sur la base de Dify. En même temps, nous vous invitons également à partager Dify lors de différents événements, conférences et réseaux sociaux. +[![Graphique de l'historique des étoiles](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) -- [Problèmes GitHub](https://github.com/langgenius/dify/issues). Idéal pour : les bogues et les erreurs que vous rencontrez en utilisant Dify.AI, voir le [Guide de contribution](CONTRIBUTING.md). -- [Support par courriel](mailto:hello@dify.ai?subject=[GitHub]Questions%20About%20Dify). Idéal pour : les questions que vous avez au sujet de l'utilisation de Dify.AI. -- [Discord](https://discord.gg/FngNHpbcY7). Idéal pour : partager vos applications et discuter avec la communauté. -- [Twitter](https://twitter.com/dify_ai). Idéal pour : partager vos applications et discuter avec la communauté. -- [Licence commerciale](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry). Idéal pour : les demandes commerciales de licence de Dify.AI pour un usage commercial. -## Divulgation de la sécurité +## Divulgation de sécurité -Pour protéger votre vie privée, veuillez éviter de publier des problèmes de sécurité sur GitHub. Envoyez plutôt vos questions à security@dify.ai et nous vous fournirons une réponse plus détaillée. +Pour protéger votre vie privée, veuillez éviter de publier des problèmes de sécurité sur GitHub. Au lieu de cela, envoyez vos questions à security@dify.ai et nous vous fournirons une réponse plus détaillée. -## Licence +## Licence -Ce référentiel est disponible sous la [Licence open source Dify](LICENSE), qui est essentiellement Apache 2.0 avec quelques restrictions supplémentaires. +Ce référentiel est disponible sous la [Licence open source Dify](LICENSE), qui est essentiellement l'Apache 2.0 avec quelques restrictions supplémentaires. diff --git a/README_JA.md b/README_JA.md index f3655b4767bad..5e1bf79e29a2a 100644 --- a/README_JA.md +++ b/README_JA.md @@ -1,130 +1,249 @@ -[![](./images/describe.png)](https://dify.ai) +![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +

- English | - 简体中文 | - 日本語 | - Español | - Klingon | - Français + Dify Cloud · + 自己ホスティング · + ドキュメント · + デモのスケジュール

- Static Badge + Static Badge + + Static Badge - chat on Discord + Discordでチャット - follow on Twitter + Twitterでフォロー - Docker Pulls + Docker Pulls + + 先月のコミット + + クローズされた問題 + + ディスカッション投稿

- - Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs - + 先月のコミット + 先月のコミット + 先月のコミット + 先月のコミット + 先月のコミット + 先月のコミット

+# -"Difyは、既にDify.AI上で10万以上のアプリケーションが構築されているLLMアプリケーション開発プラットフォームです。バックエンド・アズ・ア・サービスとLLMOpsの概念を統合し、組み込みのRAGエンジンを含む、生成AIネイティブアプリケーションを構築するためのコアテックスタックをカバーしています。Difyを使用すると、どのLLMに基づいても、Assistants APIやGPTのような機能を自己デプロイすることができます。" - -Please note that translating complex technical terms can sometimes result in slight variations in meaning due to differences in language nuances. - -![](./images/demo.png) - -## クラウドサービスの利用 - -[Dify.AI Cloud](https://dify.ai) を使用すると、オープンソース版の全機能を利用でき、さらに200GPTのトライアルクレジットが無料で提供されます。 - -## Difyの利点 - -Difyはモデルニュートラルであり、LangChainのようなハードコードされた開発ライブラリと比較して、完全にエンジニアリングされた技術スタックを特徴としています。OpenAIのAssistants APIとは異なり、Difyではサービスの完全なローカルデプロイメントが可能です。 - -| 機能 | Dify.AI | Assistants API | LangChain | -|---------|---------|----------------|-----------| -| **プログラミングアプローチ** | API指向 | API指向 | Pythonコード指向 | -| **エコシステム戦略** | オープンソース | 閉鎖的かつ商業的 | オープンソース | -| **RAGエンジン** | サポート済み | サポート済み | 非サポート | -| **プロンプトIDE** | 含まれる | 含まれる | なし | -| **サポートされるLLMs** | 豊富な種類 | GPTのみ | 豊富な種類 | -| **ローカルデプロイメント** | サポート済み | 非サポート | 該当なし | - - ## 機能 - -![](./images/models.png) - -**1\. LLMサポート**: OpenAIのGPTファミリーモデルやLlama2ファミリーのオープンソースモデルとの統合。 実際、Difyは主要な商用モデルとオープンソースモデル(ローカルでデプロイまたはMaaSベース)をサポートしています。 - -**2\. プロンプトIDE**: チームとのLLMベースのアプリケーションとサービスの視覚的なオーケストレーション。 - -**3\. RAGエンジン**: フルテキストインデックスまたはベクトルデータベース埋め込みに基づくさまざまなRAG機能を含み、PDF、TXT、その他のテキストフォーマットの直接アップロードを可能にします。 - -**4. AIエージェント**: 関数呼び出しとReActに基づくAgent推論フレームワークにより、ユーザーはツールをカスタマイズすることができます。Difyは、Google検索、DELL·E、Stable Diffusion、WolframAlphaなど、十数種類の組み込みツール呼び出し機能を提供しています。 - -**5\. 継続的運用**: アプリケーションログとパフォーマンスを監視および分析し、運用データを使用してプロンプト、データセット、またはモデルを継続的に改善します。 - -## 開始する前に - -**私たちをスターして、GitHub上でのすべての新しいリリースに対する即時通知を受け取ります!** - -![私たちをスターして](https://github.com/langgenius/dify/assets/100913391/95f37259-7370-4456-a9f0-0bc01ef8642f) - -- [Website](https://dify.ai) -- [Docs](https://docs.dify.ai) -- [Deployment Docs](https://docs.dify.ai/getting-started/install-self-hosted) -- [FAQ](https://docs.dify.ai/getting-started/faq) +

+ langgenius%2Fdify | Trendshift +

+DifyはオープンソースのLLMアプリケーション開発プラットフォームです。直感的なインターフェースには、AIワークフロー、RAGパイプライン、エージェント機能、モデル管理、観測機能などが組み合わさっており、プロトタイプから本番までの移行を迅速に行うことができます。以下は、主要機能のリストです: +

+ +**1. ワークフロー**: + ビジュアルキャンバス上で強力なAIワークフローを構築してテストし、以下の機能を活用してプロトタイプを超えることができます。 + + + https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa + + + +**2. 網羅的なモデルサポート**: + 数百のプロプライエタリ/オープンソースのLLMと、数十の推論プロバイダーおよびセルフホスティングソリューションとのシームレスな統合を提供します。GPT、Mistral、Llama2、およびOpenAI API互換のモデルをカバーします。サポートされているモデルプロバイダーの完全なリストは[こちら](https://docs + +.dify.ai/getting-started/readme/model-providers)をご覧ください。 + +![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) + + +**3. プロンプトIDE**: + チャットベースのアプリにテキスト読み上げなどの追加機能を追加するプロンプトを作成し、モデルのパフォーマンスを比較する直感的なインターフェース。 + +**4. RAGパイプライン**: + 文書の取り込みから取得までをカバーする幅広いRAG機能で、PDF、PPTなどの一般的なドキュメント形式からのテキスト抽出に対するアウトオブボックスのサポートを提供します。 + +**5. エージェント機能**: + LLM関数呼び出しまたはReActに基づいてエージェントを定義し、エージェント向けの事前構築済みまたはカスタムのツールを追加できます。Difyには、Google検索、DELL·E、Stable Diffusion、WolframAlphaなどのAIエージェント用の50以上の組み込みツールが用意されています。 + +**6. LLMOps**: + アプリケーションログとパフォーマンスを時間の経過とともにモニタリングおよび分析します。本番データと注釈に基づいて、プロンプト、データセット、およびモデルを継続的に改善できます。 + +**7. Backend-as-a-Service**: + Difyのすべての提供には、それに対応するAPIが付属しており、独自のビジネスロジックにDifyをシームレスに統合できます。 + + +## 機能比較 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
機能Dify.AILangChainFlowiseOpenAI Assistants API
プログラミングアプローチAPI + アプリ指向Pythonコードアプリ指向API指向
サポートされているLLM豊富なバリエーション豊富なバリエーション豊富なバリエーションOpenAIのみ
RAGエンジン
エージェント
ワークフロー
観測性
エンタープライズ機能(SSO/アクセス制御)
ローカル展開
+ +## Difyの使用方法 + +- **クラウド
** +[こちら](https://dify.ai)のDify Cloudサービスを利用して、セットアップが不要で誰でも試すことができます。サンドボックスプランでは、200回の無料のGPT-4呼び出しが含まれています。 + +- **Dify Community Editionのセルフホスティング
** +この[スターターガイド](#quick-start)を使用して、環境でDifyをすばやく実行できます。 +さらなる参照や詳細な手順については、[ドキュメント](https://docs.dify.ai)をご覧ください。 + +- **エンタープライズ/組織向けのDify
** +追加のエンタープライズ向け機能を提供しています。[こちらからミーティングを予約](https://cal.com/guchenhe/30min)したり、[メールを送信](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)してエンタープライズのニーズについて相談してください。
+ > AWSを使用しているスタートアップや中小企業の場合は、[AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6)のDify Premiumをチェックして、ワンクリックで独自のAWS VPCにデプロイできます。カスタムロゴとブランディングでアプリを作成するオプションを備えた手頃な価格のAMIオファリングです。 + + +## 先を見る + +GitHubでDifyにスターを付け、新しいリリースをすぐに通知されます。 + +![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) + + + +## クイックスタート +> Difyをインストールする前に、マシンが以下の最小システム要件を満たしていることを確認してください: +> +>- CPU >= 2コア +>- RAM >= 4GB + +
+ +Difyサーバーを起動する最も簡単な方法は、当社の[docker-compose.yml](docker/docker-compose.yaml)ファイルを実行することです。インストールコマンドを実行する前に、マシンに[Docker](https://docs.docker.com/get-docker/)と[Docker Compose](https://docs.docker.com/compose/install/)がインストールされていることを確認してください。 -## コミュニティエディションのインストール +```bash +cd docker +docker compose up -d +``` -### システム要件 +実行後、ブラウザで[http://localhost/install](http://localhost/install)にアクセスし、初期化プロセスを開始できます。 -Difyをインストールする前に、以下の最低限のシステム要件を満たしていることを確認してください: +> Difyに貢献したり、追加の開発を行う場合は、[ソースコードからのデプロイガイド](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code)を参照してください。 -- CPU >= 2コア -- RAM >= 4GB +## 次のステップ -### クイックスタート +環境設定をカスタマイズする場合は、[docker-compose.yml](docker/docker-compose.yaml)ファイル内のコメントを参照して、環境設定を手動で設定してください。変更を加えた後は、再び `docker-compose up -d` を実行してください。環境変数の完全なリストは[こちら](https://docs.dify.ai/getting-started/install-self-hosted/environments)をご覧ください。 -Difyサーバーを始める最も簡単な方法は、[docker-compose.yml](docker/docker-compose.yaml) ファイルを実行することです。インストールコマンドを実行する前に、マシンに [Docker](https://docs.docker.com/get-docker/) と [Docker Compose](https://docs.docker.com/compose/install/) がインストールされていることを確認してください: +高可用性のセットアップを構成する場合は、コミュニティによって提供されている[Helm Charts](https://helm.sh/)があり、これによりKubernetes上にDifyを展開できます。 -```bash -cd docker -docker compose up -d -``` +- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) +- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) -実行後、ブラウザで [http://localhost/install](http://localhost/install) にアクセスし、初期化インストールプロセスを開始できます。 -### Helm Chart +## 貢献 -@BorisPolonskyによる[Helm Chart](https://helm.sh/) バージョンを提供してくれて、大変感謝しています。これにより、DifyはKubernetes上にデプロイすることができます。 -デプロイ情報については、https://github.com/BorisPolonsky/dify-helm をご覧ください。 +コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)を参照してください。 +同時に、DifyをSNSやイベント、カンファレンスで共有してサポートしていただけると幸いです。 -### 設定 -設定をカスタマイズする必要がある場合は、[docker-compose.yml](docker/docker-compose.yaml) ファイルのコメントを参照し、環境設定を手動で行ってください。変更を行った後は、もう一度 `docker-compose up -d` を実行してください。環境変数の完全なリストは、[ドキュメント](https://docs.dify.ai/getting-started/install-self-hosted/environments)で確認できます。 +> Difyを英語または中国語以外の言語に翻訳してくれる貢献者を募集しています。興味がある場合は、詳細については[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md)を参照してください。また、[Discordコミュニティサーバー](https://discord.gg/8Tpq4AcN9c)の`global-users`チャンネルにコメントを残してください。 +**貢献者** -## スターヒストリー + + + -[![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) +## コミュニティ & お問い合わせ -## コミュニティとサポート +* [Github Discussion](https://github.com/langgenius/dify/discussions). 主に: フィードバックの共有や質問。 +* [GitHub Issues](https://github.com/langgenius/dify/issues). 主に: Dify.AIの使用中に遭遇したバグや機能提案。 +* [Email](mailto:support@dify.ai?subject=[GitHub]Questions%20About%20Dify). 主に: Dify.AIの使用に関する質問。 +* [Discord](https://discord.gg/FngNHpbcY7). 主に: アプリケーションの共有やコミュニティとの交流。 +* [Twitter](https://twitter.com/dify_ai). 
主に: アプリケーションの共有やコミュニティとの交流。 -Difyに貢献していただき、コードの提出、問題の報告、新しいアイデアの提供、またはDifyを基に作成した興味深く有用なAIアプリケーションの共有により、Difyをより良いものにするお手伝いを歓迎します。同時に、さまざまなイベント、会議、ソーシャルメディアでDifyを共有することも歓迎します。 +または、直接チームメンバーとミーティングをスケジュールします: -- [GitHub Issues](https://github.com/langgenius/dify/issues)。最適な使用法:Dify.AIの使用中に遭遇するバグやエラー、[貢献ガイド](CONTRIBUTING.md)を参照。 -- [Email サポート](mailto:hello@dify.ai?subject=[GitHub]Questions%20About%20Dify)。最適な使用法:Dify.AIの使用に関する質問。 -- [Discord](https://discord.gg/FngNHpbcY7)。最適な使用法:アプリケーションの共有とコミュニティとの交流。 -- [Twitter](https://twitter.com/dify_ai)。最適な使用法:アプリケーションの共有とコミュニティとの交流。 -- [ビジネスライセンス](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)。最適な使用法:Dify.AIを商業利用するためのビジネス関連の問い合わせ。 + + + + + + + + + + + + + + + + + +
連絡先目的
ミーティング無料の30分間のミーティングをスケジュールしてください。
技術サポート技術的な問題やサポートに関する質問
営業担当法人ライセンスに関するお問い合わせ
-プライバシー保護のため、GitHub へのセキュリティ問題の投稿は避けてください。代わりに、あなたの質問を security@dify.ai に送ってください。より詳細な回答を提供します。

## ライセンス

- このリポジトリは、基本的にApache 2.0にいくつかの追加制限を加えた[Difyオープンソースライセンス](LICENSE)の下で利用できます。
+このリポジトリは、[Difyオープンソースライセンス](LICENSE)の下で利用できます。これは基本的に、いくつかの追加制限を加えたApache 2.0です。

diff --git a/README_KL.md b/README_KL.md
index f26f88649f9c0..254016601e386 100644
--- a/README_KL.md
+++ b/README_KL.md
@@ -1,119 +1,250 @@
-[![](./images/describe.png)](https://dify.ai)
+![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab)
+

- English | - 简体中文 | - 日本語 | - Español | - Klingon | - Français + Dify Cloud · + Self-hosting · + Documentation · + Schedule demo

- Static Badge + Static Badge + + Static Badge - chat on Discord - follow on Twitter - Docker Pulls + Docker Pulls + + Commits last month + + Issues closed + + Discussion posts

-**Dify** Hoch LLM qorwI' pIqoDvam pagh laHta' je **100,000** pIqoDvamvam Dify.AI De'wI'. Dify leghpu' Backend chu' a Service teH LLMOps vItlhutlh, generative AI-native pIqoD teq wa'vam, vIyoD Built-in RAG engine. Dify, **'ej chenmoHmoH Hoch 'oHna' Assistant API 'ej GPTmey HoStaHbogh LLMmey.** - -![](./images/demo.png) - -## ngIl QaQ - -[Dify.AI ngIl](https://dify.ai) pIm neHlaH 'ej ghaH. cha'logh wa' DIvI' 200 GPT trial credits. - -## Dify WovmoH - -Dify Daq rIn neutrality 'ej Hoch, LangChain tInHar HubwI'. maH Daqbe'law' Qawqar, OpenAI's Assistant API Daq local neH deployment. - -| Qo'logh | Dify.AI | Assistants API | LangChain | -|---------|---------|----------------|-----------| -| **qet QaS** | API-oriented | API-oriented | Python Code-oriented | -| **Ecosystem Strategy** | Open Source | Closed and Commercial | Open Source | -| **RAG Engine** | Ha'qu' | Ha'qu' | ghoS Ha'qu' | -| **Prompt IDE** | jaH Include | jaH Include | qeylIS qaq | -| **qet LLMmey** | bo'Degh Hoch | GPTmey tIn | bo'Degh Hoch | -| **local deployment** | Ha'qu' | tInHa'qu' | tInHa'qu' ghogh | - -## ruch - -![](./images/models.png) - -**1. LLM tIq**: OpenAI's GPT Hur nISmoHvam neH vIngeH, wa' Llama2 Hur nISmoHvam. Heghlu'lu'pu' Dify mIw 'oH choH qay'be'.Daq commercial Hurmey 'ej Open Source Hurmey (maqtaHvIS pagh locally neH neH deployment HoSvam). - -**2. Prompt IDE**: cha'logh wa' LLMmey Hoch janlu'pu' 'ej lughpu' choH qay'be'. +

+ Commits last month + Commits last month + Commits last month + Commits last month + Commits last month + Commits last month +

-**3. RAG Engine**: RAG vaD tIqpu' lo'taH indexing qor neH vector database wa' embeddings wIj, PDFs, TXTs, 'ej ghojmoHmoH HIq qorlIj je upload. +# -**4. AI Agent**: Function Calling 'ej ReAct Daq Hurmey, Agent inference framework Hoch users customize tools, vaj 'oH QaQ. Dify Hoch loS ghaH 'ej wa'vatlh built-in tool calling capabilities, Google Search, DELL·E, Stable Diffusion, WolframAlpha, 'ej. +

+ langgenius%2Fdify | Trendshift +

+Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. Here's a list of the core features: +

+ +**1. Workflow**: + Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond. + + + https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa + + + +**2. Comprehensive model support**: + Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama2, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers). + +![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) + + +**3. Prompt IDE**: + Intuitive interface for crafting prompts, comparing model performance, and adding additional features such as text-to-speech to a chat-based app. + +**4. RAG Pipeline**: + Extensive RAG capabilities that cover everything from document ingestion to retrieval, with out-of-box support for text extraction from PDFs, PPTs, and other common document formats. + +**5. Agent capabilities**: + You can define agents based on LLM Function Calling or ReAct, and add pre-built or custom tools for the agent. Dify provides 50+ built-in tools for AI agents, such as Google Search, DELL·E, Stable Diffusion and WolframAlpha. + +**6. LLMOps**: + Monitor and analyze application logs and performance over time. You could continuously improve prompts, datasets, and models based on production data and annotations. + +**7. Backend-as-a-Service**: + All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic. + + +## Feature Comparison + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FeatureDify.AILangChainFlowiseOpenAI Assistants API
Programming ApproachAPI + App-orientedPython CodeApp-orientedAPI-oriented
Supported LLMsRich VarietyRich VarietyRich VarietyOpenAI-only
RAG Engine
Agent
Workflow
Observability
Enterprise Feature (SSO/Access control)
Local Deployment
+ +## Using Dify + +- **Cloud
** +We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan. + +- **Self-hosting Dify Community Edition
** +Quickly get Dify running in your environment with this [starter guide](#quick-start). +Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions. + +- **Dify for Enterprise / Organizations
** +We provide additional enterprise-centric features. [Schedule a meeting with us](https://cal.com/guchenhe/30min) or [send us an email](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) to discuss enterprise needs.
+ > For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one-click. It's an affordable AMI offering with the option to create apps with custom logo and branding. + + +## Staying ahead + +Star Dify on GitHub and be instantly notified of new releases. + +![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) + + + +## Quick Start +> Before installing Dify, make sure your machine meets the following minimum system requirements: +> +>- CPU >= 2 Core +>- RAM >= 4GB + +
+ +The easiest way to start the Dify server is to run our [docker-compose.yml](docker/docker-compose.yaml) file. Before running the installation command, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine: -**5. QaS muDHa'wI': cha'logh wa' pIq mI' logs 'ej quv yIn, vItlhutlh tIq 'e'wIj lo'taHmoHmoH Prompts, vItlhutlh, Hurmey ghaH production data jatlh. +```bash +cd docker +docker compose up -d +``` -## Do'wI' qabmey lo'taH +After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process. -**maHvaD jatlhchugh, GitHub Daq Hoch chu' ghompu'vam tIqel yInob!** +> If you'd like to contribute to Dify or do additional development, refer to our [guide to deploying from source code](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code) -![star-us](https://github.com/langgenius/dify/assets/100913391/95f37259-7370-4456-a9f0-0bc01ef8642f) +## Next steps -- [Website](https://dify.ai) -- [Docs](https://docs.dify.ai) -- [lo'taHmoH Docs](https://docs.dify.ai/getting-started/install-self-hosted) -- [FAQ](https://docs.dify.ai/getting-started/faq) +If you need to customize the configuration, please refer to the comments in our [docker-compose.yml](docker/docker-compose.yaml) file and manually set the environment configuration. After making the changes, please run `docker-compose up -d` again. You can see the full list of environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). -## Community Edition tu' yo' +If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) which allow Dify to be deployed on Kubernetes. -### System Qab +- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) +- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) -Dify yo' yo' qaqmeH SuS chenmoH 'oH qech! -- CPU >= 2 Cores -- RAM >= 4GB +## Contributing -### Quick Start +For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +At the same time, please consider supporting Dify by sharing it on social media and at events and conferences. -Dify server luHoHtaHlu' vIngeH lo'laHbe'chugh vIyoD [docker-compose.yml](docker/docker-compose.yaml) QorwI'ghach. toH yItlhutlh chenmoH luH!chugh 'ay' vaj vIneHmeH, 'ej [Docker](https://docs.docker.com/get-docker/) 'ej [Docker Compose](https://docs.docker.com/compose/install/) vaj 'oH 'e' vIneHmeH: -```bash -cd docker -docker compose up -d -``` +> We are looking for contributors to help with translating Dify to languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c). -luHoHtaHmeH HoHtaHvIS, Dify dashboard vIneHmeH vIngeH lI'wI' [http://localhost/install](http://localhost/install) 'ej 'oH initialization 'e' vIneHmeH. +**Contributors** -### Helm Chart + + + -@BorisPolonsky Dify wIq tIq ['ay'var (Helm Chart)](https://helm.sh/) version Hur yIn chu' Dify luHoHchu'. Heghlu'lu' vIneHmeH [https://github.com/BorisPolonsky/dify-helm](https://github.com/BorisPolonsky/dify-helm) 'ej vaj QaS deployment information. 
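A minimal sketch of the two customization paths described above, assuming the variable you want to change (for example `UPLOAD_FILE_SIZE_LIMIT`, which appears in `api/.env.example` later in this patch) is exposed in `docker-compose.yaml`, and that the community chart by @BorisPolonsky is served from the repository URL shown below; the comments in the compose file and the chart READMEs are the authoritative references.

```bash
# Option 1: Docker Compose - after editing the environment section of
# docker-compose.yaml by hand (e.g. UPLOAD_FILE_SIZE_LIMIT), recreate
# the containers so they pick up the new settings:
cd docker
docker-compose up -d

# Option 2: Kubernetes via a community Helm chart (repository URL assumed
# from the chart project's own install instructions):
helm repo add dify https://borispolonsky.github.io/dify-helm
helm repo update
helm install dify dify/dify
```

Re-running `docker-compose up -d` only recreates the containers whose configuration changed, so existing named volumes and data are kept.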
+## Community & Contact -### veS config +* [Github Discussion](https://github.com/langgenius/dify/discussions -chenmoHDI' config lo'taH ghaH, vItlhutlh HIq wIgharghbe'lu'pu'. toH lo'taHvIS pagh vay' vIneHmeH, 'ej `docker-compose up -d` wa'DIch. tIqmoHmeH list full wa' lo'taHvo'lu'pu' ghaH [docs](https://docs.dify.ai/getting-started/install-self-hosted/environments). +). Best for: sharing feedback and asking questions. +* [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +* [Email](mailto:support@dify.ai?subject=[GitHub]Questions%20About%20Dify). Best for: questions you have about using Dify.AI. +* [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community. +* [Twitter](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community. -## tIng qem +Or, schedule a meeting directly with a team member: -[![tIng qem Hur Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) + + + + + + + + + + + + + +
Point of ContactPurpose
Git-Hub-README-Button-3xBusiness enquiries & product feedback
Git-Hub-README-Button-2xContributions, issues & feature requests
-## choHmoH 'ej vItlhutlh +## Star History -Dify choHmoH je mIw Dify puqloD, Dify ghaHta'bogh vItlhutlh, HurDI' code, ghItlh, ghItlh qo'lu'pu'pu' qej. tIqmeH, Hurmey je, Dify Hur tIqDI' woDDaj, DuD QangmeH 'ej HInobDaq vItlhutlh HImej Dify'e'. +[![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) -- [GitHub vItlhutlh](https://github.com/langgenius/dify/issues). Hurmey: bugs 'ej errors Dify.AI tIqmeH. yImej [Contribution Guide](CONTRIBUTING.md). -- [Email QaH](mailto:hello@dify.ai?subject=[GitHub]Questions%20About%20Dify). Hurmey: questions vItlhutlh Dify.AI chaw'. -- [Discord](https://discord.gg/FngNHpbcY7). Hurmey: jIpuv 'ej jImej mIw Dify vItlhutlh. -- [Twitter](https://twitter.com/dify_ai). Hurmey: jIpuv 'ej jImej mIw Dify vItlhutlh. -- [Business License](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry). Hurmey: qurgh vItlhutlh Hurmey Dify.AI tIqbe'law'. -## bIQDaqmey bom +## Security Disclosure -taghlI' vIngeH'a'? pong security 'oH posting GitHub. yItlhutlh, toH security@dify.ai 'ej vIngeH'a'. +To protect your privacy, please avoid posting security issues on GitHub. Instead, send your questions to security@dify.ai and we will provide you with a more detailed answer. ## License -ghItlh puqloD chenmoH [Dify vItlhutlh Hur](LICENSE), ghaH nIvbogh Apache 2.0. - +This repository is available under the [Dify Open Source License](LICENSE), which is essentially Apache 2.0 with a few additional restrictions. \ No newline at end of file diff --git a/api/.env.example b/api/.env.example index bbcb7cf1ec15d..481d6ab49929b 100644 --- a/api/.env.example +++ b/api/.env.example @@ -57,7 +57,7 @@ AZURE_BLOB_ACCOUNT_URL=https://.blob.core.windows.net WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* -# Vector database configuration, support: weaviate, qdrant, milvus +# Vector database configuration, support: weaviate, qdrant, milvus, relyt VECTOR_STORE=weaviate # Weaviate configuration @@ -78,6 +78,13 @@ MILVUS_USER=root MILVUS_PASSWORD=Milvus MILVUS_SECURE=false +# Relyt configuration +RELYT_HOST=127.0.0.1 +RELYT_PORT=5432 +RELYT_USER=postgres +RELYT_PASSWORD=postgres +RELYT_DATABASE=postgres + # Upload configuration UPLOAD_FILE_SIZE_LIMIT=15 UPLOAD_FILE_BATCH_LIMIT=5 @@ -149,3 +156,7 @@ TEMPLATE_TRANSFORM_MAX_LENGTH=80000 CODE_MAX_STRING_ARRAY_LENGTH=30 CODE_MAX_OBJECT_ARRAY_LENGTH=30 CODE_MAX_NUMBER_ARRAY_LENGTH=1000 + +# API Tool configuration +API_TOOL_DEFAULT_CONNECT_TIMEOUT=10 +API_TOOL_DEFAULT_READ_TIMEOUT=60 diff --git a/api/Dockerfile b/api/Dockerfile index 678416f11361c..96b230e173b42 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -11,7 +11,8 @@ RUN apt-get update \ COPY requirements.txt /requirements.txt -RUN pip install --prefix=/pkg -r requirements.txt +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install --prefix=/pkg -r requirements.txt # production stage FROM base AS production diff --git a/api/README.md b/api/README.md index 93ee1816c649a..4069b3d88be3e 100644 --- a/api/README.md +++ b/api/README.md @@ -17,16 +17,16 @@ ```bash sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env ``` -3.5 If you use Anaconda, create a new environment and activate it +4. If you use Anaconda, create a new environment and activate it ```bash conda create --name dify python=3.10 conda activate dify ``` -4. Install dependencies +5. Install dependencies ```bash pip install -r requirements.txt ``` -5. Run migrate +6. 
Run migrate Before the first launch, migrate the database to the latest version. @@ -47,9 +47,11 @@ pip install -r requirements.txt --upgrade --force-reinstall ``` -6. Start backend: +7. Start backend: ```bash flask run --host 0.0.0.0 --port=5001 --debug ``` -7. Setup your application by visiting http://localhost:5001/console/api/setup or other apis... -8. If you need to debug local async processing, you can run `celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail`, celery can do dataset importing and other async tasks. +8. Setup your application by visiting http://localhost:5001/console/api/setup or other apis... +9. If you need to debug local async processing, please start the worker service by running +`celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail`. +The started celery app handles the async tasks, e.g. dataset importing and documents indexing. diff --git a/api/app.py b/api/app.py index aea28ac93a1cf..ad91b5636f789 100644 --- a/api/app.py +++ b/api/app.py @@ -1,16 +1,13 @@ import os -from werkzeug.exceptions import Unauthorized - if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true': from gevent import monkey + monkey.patch_all() # if os.environ.get("VECTOR_STORE") == 'milvus': import grpc.experimental.gevent - grpc.experimental.gevent.init_gevent() - import langchain - langchain.verbose = True + grpc.experimental.gevent.init_gevent() import json import logging @@ -21,6 +18,7 @@ from flask import Flask, Response, request from flask_cors import CORS +from werkzeug.exceptions import Unauthorized from commands import register_commands from config import CloudEditionConfig, Config from extensions import ( @@ -44,6 +42,7 @@ # DO NOT REMOVE BELOW from events import event_handlers from models import account, dataset, model, source, task, tool, tools, web + # DO NOT REMOVE ABOVE @@ -51,7 +50,7 @@ # fix windows platform if os.name == "nt": - os.system('tzutil /s "UTC"') + os.system('tzutil /s "UTC"') else: os.environ['TZ'] = 'UTC' time.tzset() @@ -60,6 +59,7 @@ class DifyApp(Flask): pass + # ------------- # Configuration # ------------- @@ -67,6 +67,7 @@ class DifyApp(Flask): config_type = os.getenv('EDITION', default='SELF_HOSTED') # ce edition first + # ---------------------------- # Application Factory Function # ---------------------------- @@ -192,7 +193,6 @@ def register_blueprints(app): app = create_app() celery = app.extensions["celery"] - if app.config['TESTING']: print("App is running in TESTING mode") diff --git a/api/commands.py b/api/commands.py index 75e66cc4d160b..a5944470660ee 100644 --- a/api/commands.py +++ b/api/commands.py @@ -15,36 +15,75 @@ from models.account import Tenant from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment from models.dataset import Document as DatasetDocument -from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation +from models.model import ( + Account, + App, + AppAnnotationSetting, + AppMode, + Conversation, + MessageAnnotation, +) from models.provider import Provider, ProviderModel +from services.account_service import RegisterService -@click.command('reset-password', help='Reset the account password.') -@click.option('--email', prompt=True, help='The email address of the account whose password you need to reset') -@click.option('--new-password', prompt=True, help='the new password.') -@click.option('--password-confirm', prompt=True, help='the new password confirm.') 
+@click.command("register", help="Register a new accout") +@click.option("--email", prompt=True, help="The email address of the account you want to register") +@click.option("--name", prompt=True, help="The name of the account you want to register") +@click.option("--password", prompt=True, help="The password of the account you want to register") +def register(email, name, password): + """ + Register a new accout + """ + try: + email_validate(email) + except: + click.echo(click.style("sorry. {} is not a valid email. ".format(email), fg="red")) + return + + try: + valid_password(password) + except: + click.echo(click.style("sorry. The passwords must match {} ".format(password_pattern), fg="red")) + return + + account = db.session.query(Account).filter(Account.email == email).one_or_none() + + if account: + click.echo(click.style("sorry. the account: [{}] already exists .".format(email), fg="red")) + return + + account = RegisterService.register(email, name, password) + click.echo(click.style("Congratulations!, account has been registered.", fg="green")) + + +@click.command("reset-password", help="Reset the account password.") +@click.option( + "--email", + prompt=True, + help="The email address of the account whose password you need to reset", +) +@click.option("--new-password", prompt=True, help="the new password.") +@click.option("--password-confirm", prompt=True, help="the new password confirm.") def reset_password(email, new_password, password_confirm): """ Reset password of owner account Only available in SELF_HOSTED mode """ if str(new_password).strip() != str(password_confirm).strip(): - click.echo(click.style('sorry. The two passwords do not match.', fg='red')) + click.echo(click.style("sorry. The two passwords do not match.", fg="red")) return - account = db.session.query(Account). \ - filter(Account.email == email). \ - one_or_none() + account = db.session.query(Account).filter(Account.email == email).one_or_none() if not account: - click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red')) + click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red")) return try: valid_password(new_password) except: - click.echo( - click.style('sorry. The passwords must match {} '.format(password_pattern), fg='red')) + click.echo(click.style("sorry. 
The passwords must match {} ".format(password_pattern), fg="red")) return # generate password salt @@ -57,80 +96,102 @@ def reset_password(email, new_password, password_confirm): account.password = base64_password_hashed account.password_salt = base64_salt db.session.commit() - click.echo(click.style('Congratulations!, password has been reset.', fg='green')) + click.echo(click.style("Congratulations!, password has been reset.", fg="green")) -@click.command('reset-email', help='Reset the account email.') -@click.option('--email', prompt=True, help='The old email address of the account whose email you need to reset') -@click.option('--new-email', prompt=True, help='the new email.') -@click.option('--email-confirm', prompt=True, help='the new email confirm.') +@click.command("reset-email", help="Reset the account email.") +@click.option( + "--email", + prompt=True, + help="The old email address of the account whose email you need to reset", +) +@click.option("--new-email", prompt=True, help="the new email.") +@click.option("--email-confirm", prompt=True, help="the new email confirm.") def reset_email(email, new_email, email_confirm): """ Replace account email :return: """ if str(new_email).strip() != str(email_confirm).strip(): - click.echo(click.style('Sorry, new email and confirm email do not match.', fg='red')) + click.echo(click.style("Sorry, new email and confirm email do not match.", fg="red")) return - account = db.session.query(Account). \ - filter(Account.email == email). \ - one_or_none() + account = db.session.query(Account).filter(Account.email == email).one_or_none() if not account: - click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red')) + click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red")) return try: email_validate(new_email) except: - click.echo( - click.style('sorry. {} is not a valid email. '.format(email), fg='red')) + click.echo(click.style("sorry. {} is not a valid email. ".format(email), fg="red")) return account.email = new_email db.session.commit() - click.echo(click.style('Congratulations!, email has been reset.', fg='green')) - - -@click.command('reset-encrypt-key-pair', help='Reset the asymmetric key pair of workspace for encrypt LLM credentials. ' - 'After the reset, all LLM credentials will become invalid, ' - 'requiring re-entry.' - 'Only support SELF_HOSTED mode.') -@click.confirmation_option(prompt=click.style('Are you sure you want to reset encrypt key pair?' - ' this operation cannot be rolled back!', fg='red')) + click.echo(click.style("Congratulations!, email has been reset.", fg="green")) + + +@click.command( + "reset-encrypt-key-pair", + help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. " + "After the reset, all LLM credentials will become invalid, " + "requiring re-entry." + "Only support SELF_HOSTED mode.", +) +@click.confirmation_option( + prompt=click.style( + "Are you sure you want to reset encrypt key pair?" " this operation cannot be rolled back!", + fg="red", + ) +) def reset_encrypt_key_pair(): """ Reset the encrypted key pair of workspace for encrypt LLM credentials. After the reset, all LLM credentials will become invalid, requiring re-entry. Only support SELF_HOSTED mode. 
""" - if current_app.config['EDITION'] != 'SELF_HOSTED': - click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red')) + if current_app.config["EDITION"] != "SELF_HOSTED": + click.echo(click.style("Sorry, only support SELF_HOSTED mode.", fg="red")) return tenants = db.session.query(Tenant).all() for tenant in tenants: if not tenant: - click.echo(click.style('Sorry, no workspace found. Please enter /install to initialize.', fg='red')) + click.echo( + click.style( + "Sorry, no workspace found. Please enter /install to initialize.", + fg="red", + ) + ) return tenant.encrypt_public_key = generate_key_pair(tenant.id) - db.session.query(Provider).filter(Provider.provider_type == 'custom', Provider.tenant_id == tenant.id).delete() + db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete() db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete() db.session.commit() - click.echo(click.style('Congratulations! ' - 'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green')) - - -@click.command('vdb-migrate', help='migrate vector db.') -@click.option('--scope', default='all', prompt=False, help='The scope of vector database to migrate, Default is All.') + click.echo( + click.style( + "Congratulations! " "the asymmetric key pair of workspace {} has been reset.".format(tenant.id), + fg="green", + ) + ) + + +@click.command("vdb-migrate", help="migrate vector db.") +@click.option( + "--scope", + default="all", + prompt=False, + help="The scope of vector database to migrate, Default is All.", +) def vdb_migrate(scope: str): - if scope in ['knowledge', 'all']: + if scope in ["knowledge", "all"]: migrate_knowledge_vector_database() - if scope in ['annotation', 'all']: + if scope in ["annotation", "all"]: migrate_annotation_vector_database() @@ -138,7 +199,7 @@ def migrate_annotation_vector_database(): """ Migrate annotation datas to target vector database . """ - click.echo(click.style('Start migrate annotation data.', fg='green')) + click.echo(click.style("Start migrate annotation data.", fg="green")) create_count = 0 skipped_count = 0 total_count = 0 @@ -146,42 +207,48 @@ def migrate_annotation_vector_database(): while True: try: # get apps info - apps = db.session.query(App).filter( - App.status == 'normal' - ).order_by(App.created_at.desc()).paginate(page=page, per_page=50) + apps = ( + db.session.query(App) + .filter(App.status == "normal") + .order_by(App.created_at.desc()) + .paginate(page=page, per_page=50) + ) except NotFound: break page += 1 for app in apps: total_count = total_count + 1 - click.echo(f'Processing the {total_count} app {app.id}. ' - + f'{create_count} created, {skipped_count} skipped.') + click.echo( + f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped." 
+ ) try: - click.echo('Create app annotation index: {}'.format(app.id)) - app_annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app.id - ).first() + click.echo("Create app annotation index: {}".format(app.id)) + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first() + ) if not app_annotation_setting: skipped_count = skipped_count + 1 - click.echo('App annotation setting is disabled: {}'.format(app.id)) + click.echo("App annotation setting is disabled: {}".format(app.id)) continue # get dataset_collection_binding info - dataset_collection_binding = db.session.query(DatasetCollectionBinding).filter( - DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id - ).first() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id) + .first() + ) if not dataset_collection_binding: - click.echo('App annotation collection binding is not exist: {}'.format(app.id)) + click.echo("App annotation collection binding is not exist: {}".format(app.id)) continue annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all() dataset = Dataset( id=app.id, tenant_id=app.tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) documents = [] if annotations: @@ -191,101 +258,128 @@ def migrate_annotation_vector_database(): metadata={ "annotation_id": annotation.id, "app_id": app.id, - "doc_id": annotation.id - } + "doc_id": annotation.id, + }, ) documents.append(document) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) click.echo(f"Start to migrate annotation, app_id: {app.id}.") try: vector.delete() click.echo( - click.style(f'Successfully delete vector index for app: {app.id}.', - fg='green')) + click.style( + f"Successfully delete vector index for app: {app.id}.", + fg="green", + ) + ) except Exception as e: - click.echo( - click.style(f'Failed to delete vector index for app {app.id}.', - fg='red')) + click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red")) raise e if documents: try: - click.echo(click.style( - f'Start to created vector index with {len(documents)} annotations for app {app.id}.', - fg='green')) + click.echo( + click.style( + f"Start to created vector index with {len(documents)} annotations for app {app.id}.", + fg="green", + ) + ) vector.create(documents) click.echo( - click.style(f'Successfully created vector index for app {app.id}.', fg='green')) + click.style( + f"Successfully created vector index for app {app.id}.", + fg="green", + ) + ) except Exception as e: - click.echo(click.style(f'Failed to created vector index for app {app.id}.', fg='red')) + click.echo( + click.style( + f"Failed to created vector index for app {app.id}.", + fg="red", + ) + ) raise e - click.echo(f'Successfully migrated app annotation {app.id}.') + click.echo(f"Successfully migrated app annotation {app.id}.") create_count += 1 except Exception as e: click.echo( - click.style('Create app annotation index error: {} {}'.format(e.__class__.__name__, str(e)), 
- fg='red')) + click.style( + "Create app annotation index error: {} {}".format(e.__class__.__name__, str(e)), + fg="red", + ) + ) continue click.echo( - click.style(f'Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.', - fg='green')) + click.style( + f"Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.", + fg="green", + ) + ) def migrate_knowledge_vector_database(): """ Migrate vector database datas to target vector database . """ - click.echo(click.style('Start migrate vector db.', fg='green')) + click.echo(click.style("Start migrate vector db.", fg="green")) create_count = 0 skipped_count = 0 total_count = 0 config = current_app.config - vector_type = config.get('VECTOR_STORE') + vector_type = config.get("VECTOR_STORE") page = 1 while True: try: - datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \ - .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50) + datasets = ( + db.session.query(Dataset) + .filter(Dataset.indexing_technique == "high_quality") + .order_by(Dataset.created_at.desc()) + .paginate(page=page, per_page=50) + ) except NotFound: break page += 1 for dataset in datasets: total_count = total_count + 1 - click.echo(f'Processing the {total_count} dataset {dataset.id}. ' - + f'{create_count} created, {skipped_count} skipped.') + click.echo( + f"Processing the {total_count} dataset {dataset.id}. " + + f"{create_count} created, {skipped_count} skipped." + ) try: - click.echo('Create dataset vdb index: {}'.format(dataset.id)) + click.echo("Create dataset vdb index: {}".format(dataset.id)) if dataset.index_struct_dict: - if dataset.index_struct_dict['type'] == vector_type: + if dataset.index_struct_dict["type"] == vector_type: skipped_count = skipped_count + 1 continue - collection_name = '' + collection_name = "" if vector_type == "weaviate": dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) index_struct_dict = { - "type": 'weaviate', - "vector_store": {"class_prefix": collection_name} + "type": "weaviate", + "vector_store": {"class_prefix": collection_name}, } dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == "qdrant": if dataset.collection_binding_id: - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter(DatasetCollectionBinding.id == dataset.collection_binding_id). 
\ - one_or_none() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter(DatasetCollectionBinding.id == dataset.collection_binding_id) + .one_or_none() + ) if dataset_collection_binding: collection_name = dataset_collection_binding.collection_name else: - raise ValueError('Dataset Collection Bindings is not exist!') + raise ValueError("Dataset Collection Bindings is not exist!") else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) index_struct_dict = { - "type": 'qdrant', - "vector_store": {"class_prefix": collection_name} + "type": "qdrant", + "vector_store": {"class_prefix": collection_name}, } dataset.index_struct = json.dumps(index_struct_dict) @@ -293,7 +387,15 @@ def migrate_knowledge_vector_database(): dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) index_struct_dict = { - "type": 'milvus', + "type": "milvus", + "vector_store": {"class_prefix": collection_name}, + } + dataset.index_struct = json.dumps(index_struct_dict) + elif vector_type == "relyt": + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + index_struct_dict = { + "type": 'relyt', "vector_store": {"class_prefix": collection_name} } dataset.index_struct = json.dumps(index_struct_dict) @@ -306,29 +408,43 @@ def migrate_knowledge_vector_database(): try: vector.delete() click.echo( - click.style(f'Successfully delete vector index {collection_name} for dataset {dataset.id}.', - fg='green')) + click.style( + f"Successfully delete vector index {collection_name} for dataset {dataset.id}.", + fg="green", + ) + ) except Exception as e: click.echo( - click.style(f'Failed to delete vector index {collection_name} for dataset {dataset.id}.', - fg='red')) + click.style( + f"Failed to delete vector index {collection_name} for dataset {dataset.id}.", + fg="red", + ) + ) raise e - dataset_documents = db.session.query(DatasetDocument).filter( - DatasetDocument.dataset_id == dataset.id, - DatasetDocument.indexing_status == 'completed', - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).all() + dataset_documents = ( + db.session.query(DatasetDocument) + .filter( + DatasetDocument.dataset_id == dataset.id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + .all() + ) documents = [] segments_count = 0 for dataset_document in dataset_documents: - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True - ).all() + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + ) + .all() + ) for segment in segments: document = Document( @@ -338,7 +454,7 @@ def migrate_knowledge_vector_database(): "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) @@ -346,37 +462,55 @@ def migrate_knowledge_vector_database(): if documents: try: - click.echo(click.style( - f'Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.', - fg='green')) + click.echo( + click.style( + f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.", + fg="green", + 
) + ) vector.create(documents) click.echo( - click.style(f'Successfully created vector index for dataset {dataset.id}.', fg='green')) + click.style( + f"Successfully created vector index for dataset {dataset.id}.", + fg="green", + ) + ) except Exception as e: - click.echo(click.style(f'Failed to created vector index for dataset {dataset.id}.', fg='red')) + click.echo( + click.style( + f"Failed to created vector index for dataset {dataset.id}.", + fg="red", + ) + ) raise e db.session.add(dataset) db.session.commit() - click.echo(f'Successfully migrated dataset {dataset.id}.') + click.echo(f"Successfully migrated dataset {dataset.id}.") create_count += 1 except Exception as e: db.session.rollback() click.echo( - click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), - fg='red')) + click.style( + "Create dataset index error: {} {}".format(e.__class__.__name__, str(e)), + fg="red", + ) + ) continue click.echo( - click.style(f'Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.', - fg='green')) + click.style( + f"Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.", + fg="green", + ) + ) -@click.command('convert-to-agent-apps', help='Convert Agent Assistant to Agent App.') +@click.command("convert-to-agent-apps", help="Convert Agent Assistant to Agent App.") def convert_to_agent_apps(): """ Convert Agent Assistant to Agent App. """ - click.echo(click.style('Start convert to agent apps.', fg='green')) + click.echo(click.style("Start convert to agent apps.", fg="green")) proceeded_app_ids = [] @@ -411,7 +545,7 @@ def convert_to_agent_apps(): break for app in apps: - click.echo('Converting app: {}'.format(app.id)) + click.echo("Converting app: {}".format(app.id)) try: app.mode = AppMode.AGENT_CHAT.value @@ -423,16 +557,25 @@ def convert_to_agent_apps(): ) db.session.commit() - click.echo(click.style('Converted app: {}'.format(app.id), fg='green')) + click.echo(click.style("Converted app: {}".format(app.id), fg="green")) except Exception as e: click.echo( - click.style('Convert app error: {} {}'.format(e.__class__.__name__, - str(e)), fg='red')) + click.style( + "Convert app error: {} {}".format(e.__class__.__name__, str(e)), + fg="red", + ) + ) - click.echo(click.style('Congratulations! Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green')) + click.echo( + click.style( + "Congratulations! 
Converted {} agent apps.".format(len(proceeded_app_ids)), + fg="green", + ) + ) def register_commands(app): + app.cli.add_command(register) app.cli.add_command(reset_password) app.cli.add_command(reset_email) app.cli.add_command(reset_encrypt_key_pair) diff --git a/api/config.py b/api/config.py index 4a579367be12a..f210ac48f9dca 100644 --- a/api/config.py +++ b/api/config.py @@ -42,7 +42,7 @@ 'HOSTED_OPENAI_TRIAL_ENABLED': 'False', 'HOSTED_OPENAI_TRIAL_MODELS': 'gpt-3.5-turbo,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,text-davinci-003', 'HOSTED_OPENAI_PAID_ENABLED': 'False', - 'HOSTED_OPENAI_PAID_MODELS': 'gpt-4,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-4-0125-preview,gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,gpt-3.5-turbo-instruct,text-davinci-003', + 'HOSTED_OPENAI_PAID_MODELS': 'gpt-4,gpt-4-turbo-preview,gpt-4-turbo-2024-04-09,gpt-4-1106-preview,gpt-4-0125-preview,gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,gpt-3.5-turbo-instruct,text-davinci-003', 'HOSTED_AZURE_OPENAI_ENABLED': 'False', 'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200, 'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000, @@ -64,9 +64,10 @@ 'ETL_TYPE': 'dify', 'KEYWORD_STORE': 'jieba', 'BATCH_UPLOAD_LIMIT': 20, - 'CODE_EXECUTION_ENDPOINT': '', - 'CODE_EXECUTION_API_KEY': '', + 'CODE_EXECUTION_ENDPOINT': 'http://sandbox:8194', + 'CODE_EXECUTION_API_KEY': 'dify-sandbox', 'TOOL_ICON_CACHE_MAX_AGE': 3600, + 'MILVUS_DATABASE': 'default', 'KEYWORD_DATA_SOURCE_TYPE': 'database', } @@ -98,7 +99,7 @@ def __init__(self): # ------------------------ # General Configurations. # ------------------------ - self.CURRENT_VERSION = "0.6.0" + self.CURRENT_VERSION = "0.6.3" self.COMMIT_SHA = get_env('COMMIT_SHA') self.EDITION = "SELF_HOSTED" self.DEPLOY_ENV = get_env('DEPLOY_ENV') @@ -197,7 +198,7 @@ def __init__(self): # ------------------------ # Vector Store Configurations. - # Currently, only support: qdrant, milvus, zilliz, weaviate + # Currently, only support: qdrant, milvus, zilliz, weaviate, relyt # ------------------------ self.VECTOR_STORE = get_env('VECTOR_STORE') self.KEYWORD_STORE = get_env('KEYWORD_STORE') @@ -212,6 +213,7 @@ def __init__(self): self.MILVUS_USER = get_env('MILVUS_USER') self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD') self.MILVUS_SECURE = get_env('MILVUS_SECURE') + self.MILVUS_DATABASE = get_env('MILVUS_DATABASE') # weaviate settings self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT') @@ -219,6 +221,13 @@ def __init__(self): self.WEAVIATE_GRPC_ENABLED = get_bool_env('WEAVIATE_GRPC_ENABLED') self.WEAVIATE_BATCH_SIZE = int(get_env('WEAVIATE_BATCH_SIZE')) + # relyt settings + self.RELYT_HOST = get_env('RELYT_HOST') + self.RELYT_PORT = get_env('RELYT_PORT') + self.RELYT_USER = get_env('RELYT_USER') + self.RELYT_PASSWORD = get_env('RELYT_PASSWORD') + self.RELYT_DATABASE = get_env('RELYT_DATABASE') + # ------------------------ # Mail Configurations. 
# ------------------------ diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index f7890de1cb605..71b52d5cebe84 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timezone import pytz from flask_login import current_user @@ -262,7 +262,7 @@ def _get_conversation(app_model, conversation_id): raise NotFound("Conversation Not Exists.") if not conversation.read_at: - conversation.read_at = datetime.utcnow() + conversation.read_at = datetime.now(timezone.utc).replace(tzinfo=None) conversation.read_account_id = current_user.id db.session.commit() diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index 20e028af9936e..8efb55cdb64ed 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -1,6 +1,6 @@ import base64 +import datetime import secrets -from datetime import datetime from flask_restful import Resource, reqparse @@ -66,7 +66,7 @@ def post(self): account.timezone = args['timezone'] account.interface_theme = 'light' account.status = AccountStatus.ACTIVE.value - account.initialized_at = datetime.utcnow() + account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() return {'result': 'success'} diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 05b1c36873a55..e5b80e9a57be3 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -1,5 +1,5 @@ import logging -from datetime import datetime +from datetime import datetime, timezone from typing import Optional import requests @@ -73,7 +73,7 @@ def get(self, provider: str): if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value - account.initialized_at = datetime.utcnow() + account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() TenantService.create_owner_tenant_if_not_exist(account) diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index f3e639c6acfd9..8b210cc756bc0 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -80,7 +80,7 @@ def patch(self, binding_id, action): if action == 'enable': if data_source_binding.disabled: data_source_binding.disabled = False - data_source_binding.updated_at = datetime.datetime.utcnow() + data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(data_source_binding) db.session.commit() else: @@ -89,7 +89,7 @@ def patch(self, binding_id, action): if action == 'disable': if not data_source_binding.disabled: data_source_binding.disabled = True - data_source_binding.updated_at = datetime.datetime.utcnow() + data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(data_source_binding) db.session.commit() else: diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index c383cdc762a6f..3d6daa76825c3 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timezone from flask import 
request from flask_login import current_user @@ -637,7 +637,7 @@ def patch(self, dataset_id, document_id, action): raise InvalidActionError('Document not in indexing state.') document.paused_by = current_user.id - document.paused_at = datetime.utcnow() + document.paused_at = datetime.now(timezone.utc).replace(tzinfo=None) document.is_paused = True db.session.commit() @@ -717,7 +717,7 @@ def put(self, dataset_id, document_id): document.doc_metadata[key] = value document.doc_type = doc_type - document.updated_at = datetime.utcnow() + document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() return {'result': 'success', 'message': 'Document metadata updated.'}, 200 @@ -755,7 +755,7 @@ def patch(self, dataset_id, document_id, action): document.enabled = True document.disabled_at = None document.disabled_by = None - document.updated_at = datetime.utcnow() + document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() # Set cache to prevent indexing the same document multiple times @@ -772,9 +772,9 @@ def patch(self, dataset_id, document_id, action): raise InvalidActionError('Document already disabled.') document.enabled = False - document.disabled_at = datetime.utcnow() + document.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None) document.disabled_by = current_user.id - document.updated_at = datetime.utcnow() + document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() # Set cache to prevent indexing the same document multiple times @@ -789,9 +789,9 @@ def patch(self, dataset_id, document_id, action): raise InvalidActionError('Document already archived.') document.archived = True - document.archived_at = datetime.utcnow() + document.archived_at = datetime.now(timezone.utc).replace(tzinfo=None) document.archived_by = current_user.id - document.updated_at = datetime.utcnow() + document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() if document.enabled: @@ -808,7 +808,7 @@ def patch(self, dataset_id, document_id, action): document.archived = False document.archived_at = None document.archived_by = None - document.updated_at = datetime.utcnow() + document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() # Set cache to prevent indexing the same document multiple times diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 7b58120a584e9..0a88a0d8d44d7 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -1,5 +1,5 @@ import uuid -from datetime import datetime +from datetime import datetime, timezone import pandas as pd from flask import request @@ -192,7 +192,7 @@ def patch(self, dataset_id, segment_id, action): raise InvalidActionError("Segment is already disabled.") segment.enabled = False - segment.disabled_at = datetime.utcnow() + segment.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None) segment.disabled_by = current_user.id db.session.commit() diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 292b4ed2a0e9d..869b56e13bf93 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -1,5 +1,5 @@ import logging -from datetime import datetime +from datetime import datetime, timezone from flask_login import current_user from flask_restful import reqparse @@ 
-47,7 +47,7 @@ def post(self, installed_app): streaming = args['response_mode'] == 'streaming' args['auto_generate_name'] = False - installed_app.last_used_at = datetime.utcnow() + installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() try: @@ -110,7 +110,7 @@ def post(self, installed_app): args['auto_generate_name'] = False - installed_app.last_used_at = datetime.utcnow() + installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() try: diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 7892840aebb24..ea0fa4e17e9d4 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -6,6 +6,7 @@ from controllers.console import api from controllers.console.explore.error import NotChatAppError from controllers.console.explore.wraps import InstalledAppResource +from core.app.entities.app_invoke_entities import InvokeFrom from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import uuid_value from models.model import AppMode @@ -39,8 +40,8 @@ def get(self, installed_app): user=current_user, last_id=args['last_id'], limit=args['limit'], + invoke_from=InvokeFrom.EXPLORE, pinned=pinned, - exclude_debug_conversation=True ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 7d6231270f23d..ec7bbed3074ad 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timezone from flask_login import current_user from flask_restful import Resource, inputs, marshal_with, reqparse @@ -81,7 +81,7 @@ def post(self): tenant_id=current_tenant_id, app_owner_tenant_id=app.tenant_id, is_pinned=False, - last_used_at=datetime.utcnow() + last_used_at=datetime.now(timezone.utc).replace(tzinfo=None) ) db.session.add(new_installed_app) db.session.commit() diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 656a4d4cee6af..198409bba78f7 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -1,4 +1,4 @@ -from datetime import datetime +import datetime import pytz from flask import current_app, request @@ -59,7 +59,7 @@ def post(self): raise InvalidInvitationCodeError() invitation_code.status = 'used' - invitation_code.used_at = datetime.utcnow() + invitation_code.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) invitation_code.used_by_tenant_id = account.current_tenant_id invitation_code.used_by_account_id = account.id @@ -67,7 +67,7 @@ def post(self): account.timezone = args['timezone'] account.interface_theme = 'light' account.status = 'active' - account.initialized_at = datetime.utcnow() + account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() return {'result': 'success'} diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index ccf743371ac96..bccce9b55bcec 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,14 +1,11 @@ -import json from flask import current_app -from flask_restful 
import fields, marshal_with, Resource +from flask_restful import Resource, fields, marshal_with from controllers.service_api import api from controllers.service_api.app.error import AppUnavailableError from controllers.service_api.wraps import validate_app_token -from extensions.ext_database import db -from models.model import App, AppModelConfig, AppMode -from models.tools import ApiToolProvider +from models.model import App, AppMode from services.app_service import AppService @@ -92,6 +89,16 @@ def get(self, app_model: App): """Get app meta""" return AppService().get_app_meta(app_model) +class AppInfoApi(Resource): + @validate_app_token + def get(self, app_model: App): + """Get app information""" + return { + 'name': app_model.name, + 'description': app_model.description + } + api.add_resource(AppParameterApi, '/parameters') api.add_resource(AppMetaApi, '/meta') +api.add_resource(AppInfoApi, '/info') diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index fc60f94ec9bc5..02158f8b56d27 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -6,6 +6,7 @@ from controllers.service_api import api from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token +from core.app.entities.app_invoke_entities import InvokeFrom from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import uuid_value from models.model import App, AppMode, EndUser @@ -27,7 +28,13 @@ def get(self, app_model: App, end_user: EndUser): args = parser.parse_args() try: - return ConversationService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit']) + return ConversationService.pagination_by_last_id( + app_model=app_model, + user=end_user, + last_id=args['last_id'], + limit=args['limit'], + invoke_from=InvokeFrom.SERVICE_API + ) except services.errors.conversation.LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index a75583469e62b..70733d63f4046 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from datetime import datetime +from datetime import datetime, timezone from enum import Enum from functools import wraps from typing import Optional @@ -183,7 +183,7 @@ def validate_and_get_api_token(scope=None): if not api_token: raise Unauthorized("Access token is invalid") - api_token.last_used_at = datetime.utcnow() + api_token.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() return api_token diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index bbc57c7d61b7c..b83ea3a52596b 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -5,6 +5,7 @@ from controllers.web import api from controllers.web.error import NotChatAppError from controllers.web.wraps import WebApiResource +from core.app.entities.app_invoke_entities import InvokeFrom from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import uuid_value from models.model import AppMode @@ -37,7 +38,8 @@ def get(self, app_model, end_user): user=end_user, last_id=args['last_id'], 
limit=args['limit'], - pinned=pinned + invoke_from=InvokeFrom.WEB_APP, + pinned=pinned, ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 8955ad2d1af9b..e5b4b9a4cdf72 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -1,10 +1,11 @@ import json import logging import uuid -from datetime import datetime +from datetime import datetime, timezone from typing import Optional, Union, cast from core.agent.entities import AgentEntity, AgentToolEntity +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_runner import AppRunner @@ -14,6 +15,7 @@ ) from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.file.message_file_parser import MessageFileParser from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMUsage @@ -22,6 +24,7 @@ PromptMessage, PromptMessageTool, SystemPromptMessage, + TextPromptMessageContent, ToolPromptMessage, UserPromptMessage, ) @@ -37,7 +40,7 @@ from core.tools.tool.tool import Tool from core.tools.tool_manager import ToolManager from extensions.ext_database import db -from models.model import Message, MessageAgentThought +from models.model import Conversation, Message, MessageAgentThought from models.tools import ToolConversationVariables logger = logging.getLogger(__name__) @@ -45,6 +48,7 @@ class BaseAgentRunner(AppRunner): def __init__(self, tenant_id: str, application_generate_entity: AgentChatAppGenerateEntity, + conversation: Conversation, app_config: AgentChatAppConfig, model_config: ModelConfigWithCredentialsEntity, config: AgentEntity, @@ -72,6 +76,7 @@ def __init__(self, tenant_id: str, """ self.tenant_id = tenant_id self.application_generate_entity = application_generate_entity + self.conversation = conversation self.app_config = app_config self.model_config = model_config self.config = config @@ -118,6 +123,12 @@ def __init__(self, tenant_id: str, else: self.stream_tool_call = False + # check if model supports vision + if model_schema and ModelFeature.VISION in (model_schema.features or []): + self.files = application_generate_entity.files + else: + self.files = [] + def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \ -> AgentChatAppGenerateEntity: """ @@ -227,6 +238,34 @@ def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRe return prompt_tool + def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]: + """ + Init tools + """ + tool_instances = {} + prompt_messages_tools = [] + + for tool in self.app_config.agent.tools if self.app_config.agent else []: + try: + prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) + except Exception: + # api tool may be deleted + continue + # save tool entity + tool_instances[tool.tool_name] = tool_entity + # save prompt tool + prompt_messages_tools.append(prompt_tool) + + # convert dataset tools into ModelRuntime Tool format + for dataset_tool in self.dataset_tools: + prompt_tool = 
self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool) + # save prompt tool + prompt_messages_tools.append(prompt_tool) + # save tool entity + tool_instances[dataset_tool.identity.name] = dataset_tool + + return tool_instances, prompt_messages_tools + def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool: """ update prompt message tool @@ -314,7 +353,7 @@ def save_agent_thought(self, tool_name: str, tool_input: Union[str, dict], thought: str, - observation: Union[str, str], + observation: Union[str, dict], tool_invoke_meta: Union[str, dict], answer: str, messages_ids: list[str], @@ -401,7 +440,7 @@ def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variab ToolConversationVariables.conversation_id == self.message.conversation_id, ).first() - db_variables.updated_at = datetime.utcnow() + db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) db.session.commit() db.session.close() @@ -412,15 +451,19 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P """ result = [] # check if there is a system message in the beginning of the conversation - if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage): - result.append(prompt_messages[0]) + for prompt_message in prompt_messages: + if isinstance(prompt_message, SystemPromptMessage): + result.append(prompt_message) messages: list[Message] = db.session.query(Message).filter( Message.conversation_id == self.message.conversation_id, ).order_by(Message.created_at.asc()).all() for message in messages: - result.append(UserPromptMessage(content=message.query)) + if message.id == self.message.id: + continue + + result.append(self.organize_agent_user_prompt(message)) agent_thoughts: list[MessageAgentThought] = message.agent_thoughts if agent_thoughts: for agent_thought in agent_thoughts: @@ -471,3 +514,32 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P db.session.close() return result + + def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: + message_file_parser = MessageFileParser( + tenant_id=self.tenant_id, + app_id=self.app_config.app_id, + ) + + files = message.message_files + if files: + file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) + + if file_extra_config: + file_objs = message_file_parser.transform_message_files( + files, + file_extra_config + ) + else: + file_objs = [] + + if not file_objs: + return UserPromptMessage(content=message.query) + else: + prompt_message_contents = [TextPromptMessageContent(data=message.query)] + for file_obj in file_objs: + prompt_message_contents.append(file_obj.prompt_message_content) + + return UserPromptMessage(content=prompt_message_contents) + else: + return UserPromptMessage(content=message.query) diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index d57f15638ccc7..ed55d1b02203b 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -1,33 +1,36 @@ import json -import re +from abc import ABC, abstractmethod from collections.abc import Generator -from typing import Literal, Union +from typing import Union from core.agent.base_agent_runner import BaseAgentRunner -from core.agent.entities import AgentPromptEntity, AgentScratchpadUnit +from core.agent.entities import AgentScratchpadUnit +from 
core.agent.output_parser.cot_output_parser import CotAgentOutputParser from core.app.apps.base_app_queue_manager import PublishFrom from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, - PromptMessageTool, - SystemPromptMessage, ToolPromptMessage, UserPromptMessage, ) -from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.tool_entities import ToolInvokeMeta +from core.tools.tool.tool import Tool from core.tools.tool_engine import ToolEngine -from models.model import Conversation, Message +from models.model import Message -class CotAgentRunner(BaseAgentRunner): +class CotAgentRunner(BaseAgentRunner, ABC): _is_first_iteration = True _ignore_observation_providers = ['wenxin'] + _historic_prompt_messages: list[PromptMessage] = None + _agent_scratchpad: list[AgentScratchpadUnit] = None + _instruction: str = None + _query: str = None + _prompt_messages_tools: list[PromptMessage] = None - def run(self, conversation: Conversation, - message: Message, + def run(self, message: Message, query: str, inputs: dict[str, str], ) -> Union[Generator, LLMResult]: @@ -36,9 +39,7 @@ def run(self, conversation: Conversation, """ app_generate_entity = self.application_generate_entity self._repack_app_generate_entity(app_generate_entity) - - agent_scratchpad: list[AgentScratchpadUnit] = [] - self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages) + self._init_react_state(query) # check model mode if 'Observation' not in app_generate_entity.model_config.stop: @@ -47,38 +48,19 @@ def run(self, conversation: Conversation, app_config = self.app_config - # override inputs + # init instruction inputs = inputs or {} instruction = app_config.prompt_template.simple_prompt_template - instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) + self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) iteration_step = 1 max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 - prompt_messages = self.history_prompt_messages - # convert tools into ModelRuntime Tool format - prompt_messages_tools: list[PromptMessageTool] = [] - tool_instances = {} - for tool in app_config.agent.tools if app_config.agent else []: - try: - prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) - except Exception: - # api tool may be deleted - continue - # save tool entity - tool_instances[tool.tool_name] = tool_entity - # save prompt tool - prompt_messages_tools.append(prompt_tool) - - # convert dataset tools into ModelRuntime Tool format - for dataset_tool in self.dataset_tools: - prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool) - # save prompt tool - prompt_messages_tools.append(prompt_tool) - # save tool entity - tool_instances[dataset_tool.identity.name] = dataset_tool + tool_instances, self._prompt_messages_tools = self._init_prompt_tools() + prompt_messages = self._organize_prompt_messages() + function_call_state = True llm_usage = { 'usage': None @@ -103,7 +85,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): if iteration_step == max_iteration_steps: # the last iteration, remove all tools - prompt_messages_tools = [] + self._prompt_messages_tools = [] message_file_ids = 
[] @@ -120,18 +102,8 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): agent_thought_id=agent_thought.id ), PublishFrom.APPLICATION_MANAGER) - # update prompt messages - prompt_messages = self._organize_cot_prompt_messages( - mode=app_generate_entity.model_config.mode, - prompt_messages=prompt_messages, - tools=prompt_messages_tools, - agent_scratchpad=agent_scratchpad, - agent_prompt_message=app_config.agent.prompt, - instruction=instruction, - input=query - ) - # recalc llm max tokens + prompt_messages = self._organize_prompt_messages() self.recalc_llm_max_tokens(self.model_config, prompt_messages) # invoke model chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm( @@ -149,7 +121,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): raise ValueError("failed to invoke llm") usage_dict = {} - react_chunks = self._handle_stream_react(chunks, usage_dict) + react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks) scratchpad = AgentScratchpadUnit( agent_response='', thought='', @@ -165,30 +137,12 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): ), PublishFrom.APPLICATION_MANAGER) for chunk in react_chunks: - if isinstance(chunk, dict): - scratchpad.agent_response += json.dumps(chunk) - try: - if scratchpad.action: - raise Exception("") - scratchpad.action_str = json.dumps(chunk) - scratchpad.action = AgentScratchpadUnit.Action( - action_name=chunk['action'], - action_input=chunk['action_input'] - ) - except: - scratchpad.thought += json.dumps(chunk) - yield LLMResultChunk( - model=self.model_config.model, - prompt_messages=prompt_messages, - system_fingerprint='', - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage( - content=json.dumps(chunk, ensure_ascii=False) # if ensure_ascii=True, the text in webui maybe garbled text - ), - usage=None - ) - ) + if isinstance(chunk, AgentScratchpadUnit.Action): + action = chunk + # detect action + scratchpad.agent_response += json.dumps(chunk.dict()) + scratchpad.action_str = json.dumps(chunk.dict()) + scratchpad.action = action else: scratchpad.agent_response += chunk scratchpad.thought += chunk @@ -206,27 +160,29 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): ) scratchpad.thought = scratchpad.thought.strip() or 'I am thinking about how to help you' - agent_scratchpad.append(scratchpad) - + self._agent_scratchpad.append(scratchpad) + # get llm usage if 'usage' in usage_dict: increase_usage(llm_usage, usage_dict['usage']) else: usage_dict['usage'] = LLMUsage.empty_usage() - self.save_agent_thought(agent_thought=agent_thought, - tool_name=scratchpad.action.action_name if scratchpad.action else '', - tool_input={ - scratchpad.action.action_name: scratchpad.action.action_input - } if scratchpad.action else '', - tool_invoke_meta={}, - thought=scratchpad.thought, - observation='', - answer=scratchpad.agent_response, - messages_ids=[], - llm_usage=usage_dict['usage']) + self.save_agent_thought( + agent_thought=agent_thought, + tool_name=scratchpad.action.action_name if scratchpad.action else '', + tool_input={ + scratchpad.action.action_name: scratchpad.action.action_input + } if scratchpad.action else {}, + tool_invoke_meta={}, + thought=scratchpad.thought, + observation='', + answer=scratchpad.agent_response, + messages_ids=[], + llm_usage=usage_dict['usage'] + ) - if scratchpad.action and scratchpad.action.action_name.lower() != "final answer": + if not 
scratchpad.is_final(): self.queue_manager.publish(QueueAgentThoughtEvent( agent_thought_id=agent_thought.id ), PublishFrom.APPLICATION_MANAGER) @@ -238,106 +194,43 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): if scratchpad.action.action_name.lower() == "final answer": # action is final answer, return final answer directly try: - final_answer = scratchpad.action.action_input if \ - isinstance(scratchpad.action.action_input, str) else \ - json.dumps(scratchpad.action.action_input) + if isinstance(scratchpad.action.action_input, dict): + final_answer = json.dumps(scratchpad.action.action_input) + elif isinstance(scratchpad.action.action_input, str): + final_answer = scratchpad.action.action_input + else: + final_answer = f'{scratchpad.action.action_input}' except json.JSONDecodeError: final_answer = f'{scratchpad.action.action_input}' else: function_call_state = True - # action is tool call, invoke tool - tool_call_name = scratchpad.action.action_name - tool_call_args = scratchpad.action.action_input - tool_instance = tool_instances.get(tool_call_name) - if not tool_instance: - answer = f"there is not a tool named {tool_call_name}" - self.save_agent_thought( - agent_thought=agent_thought, - tool_name='', - tool_input='', - tool_invoke_meta=ToolInvokeMeta.error_instance( - f"there is not a tool named {tool_call_name}" - ).to_dict(), - thought=None, - observation={ - tool_call_name: answer - }, - answer=answer, - messages_ids=[] - ) - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) - else: - if isinstance(tool_call_args, str): - try: - tool_call_args = json.loads(tool_call_args) - except json.JSONDecodeError: - pass - - # invoke tool - tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke( - tool=tool_instance, - tool_parameters=tool_call_args, - user_id=self.user_id, - tenant_id=self.tenant_id, - message=self.message, - invoke_from=self.application_generate_entity.invoke_from, - agent_tool_callback=self.agent_callback - ) - # publish files - for message_file, save_as in message_files: - if save_as: - self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as) - - # publish message file - self.queue_manager.publish(QueueMessageFileEvent( - message_file_id=message_file.id - ), PublishFrom.APPLICATION_MANAGER) - # add message file ids - message_file_ids.append(message_file.id) - - # publish files - for message_file, save_as in message_files: - if save_as: - self.variables_pool.set_file(tool_name=tool_call_name, - value=message_file.id, - name=save_as) - self.queue_manager.publish(QueueMessageFileEvent( - message_file_id=message_file.id - ), PublishFrom.APPLICATION_MANAGER) - - message_file_ids = [message_file.id for message_file, _ in message_files] - - observation = tool_invoke_response - - # save scratchpad - scratchpad.observation = observation - - # save agent thought - self.save_agent_thought( - agent_thought=agent_thought, - tool_name=tool_call_name, - tool_input={ - tool_call_name: tool_call_args - }, - tool_invoke_meta={ - tool_call_name: tool_invoke_meta.to_dict() - }, - thought=None, - observation={ - tool_call_name: observation - }, - answer=scratchpad.agent_response, - messages_ids=message_file_ids, - ) - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + tool_invoke_response, tool_invoke_meta = self._handle_invoke_action( + 
action=scratchpad.action, + tool_instances=tool_instances, + message_file_ids=message_file_ids + ) + scratchpad.observation = tool_invoke_response + scratchpad.agent_response = tool_invoke_response + + self.save_agent_thought( + agent_thought=agent_thought, + tool_name=scratchpad.action.action_name, + tool_input={scratchpad.action.action_name: scratchpad.action.action_input}, + thought=scratchpad.thought, + observation={scratchpad.action.action_name: tool_invoke_response}, + tool_invoke_meta=tool_invoke_meta.to_dict(), + answer=scratchpad.agent_response, + messages_ids=message_file_ids, + llm_usage=usage_dict['usage'] + ) + + self.queue_manager.publish(QueueAgentThoughtEvent( + agent_thought_id=agent_thought.id + ), PublishFrom.APPLICATION_MANAGER) # update prompt tool message - for prompt_tool in prompt_messages_tools: + for prompt_tool in self._prompt_messages_tools: self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool) iteration_step += 1 @@ -379,96 +272,63 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): system_fingerprint='' )), PublishFrom.APPLICATION_MANAGER) - def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \ - -> Generator[Union[str, dict], None, None]: - def parse_json(json_str): + def _handle_invoke_action(self, action: AgentScratchpadUnit.Action, + tool_instances: dict[str, Tool], + message_file_ids: list[str]) -> tuple[str, ToolInvokeMeta]: + """ + handle invoke action + :param action: action + :param tool_instances: tool instances + :return: observation, meta + """ + # action is tool call, invoke tool + tool_call_name = action.action_name + tool_call_args = action.action_input + tool_instance = tool_instances.get(tool_call_name) + + if not tool_instance: + answer = f"there is not a tool named {tool_call_name}" + return answer, ToolInvokeMeta.error_instance(answer) + + if isinstance(tool_call_args, str): try: - return json.loads(json_str.strip()) - except: - return json_str - - def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]: - code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL) - if not code_blocks: - return - for block in code_blocks: - json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE) - yield parse_json(json_text) - - code_block_cache = '' - code_block_delimiter_count = 0 - in_code_block = False - json_cache = '' - json_quote_count = 0 - in_json = False - got_json = False - - for response in llm_response: - response = response.delta.message.content - if not isinstance(response, str): - continue + tool_call_args = json.loads(tool_call_args) + except json.JSONDecodeError: + pass + + # invoke tool + tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke( + tool=tool_instance, + tool_parameters=tool_call_args, + user_id=self.user_id, + tenant_id=self.tenant_id, + message=self.message, + invoke_from=self.application_generate_entity.invoke_from, + agent_tool_callback=self.agent_callback + ) - # stream - index = 0 - while index < len(response): - steps = 1 - delta = response[index:index+steps] - if delta == '`': - code_block_cache += delta - code_block_delimiter_count += 1 - else: - if not in_code_block: - if code_block_delimiter_count > 0: - yield code_block_cache - code_block_cache = '' - else: - code_block_cache += delta - code_block_delimiter_count = 0 - - if code_block_delimiter_count == 3: - if in_code_block: - yield from extra_json_from_code_block(code_block_cache) - 
code_block_cache = '' - - in_code_block = not in_code_block - code_block_delimiter_count = 0 - - if not in_code_block: - # handle single json - if delta == '{': - json_quote_count += 1 - in_json = True - json_cache += delta - elif delta == '}': - json_cache += delta - if json_quote_count > 0: - json_quote_count -= 1 - if json_quote_count == 0: - in_json = False - got_json = True - index += steps - continue - else: - if in_json: - json_cache += delta - - if got_json: - got_json = False - yield parse_json(json_cache) - json_cache = '' - json_quote_count = 0 - in_json = False - - if not in_code_block and not in_json: - yield delta.replace('`', '') - - index += steps - - if code_block_cache: - yield code_block_cache - - if json_cache: - yield parse_json(json_cache) + # publish files + for message_file, save_as in message_files: + if save_as: + self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as) + + # publish message file + self.queue_manager.publish(QueueMessageFileEvent( + message_file_id=message_file.id + ), PublishFrom.APPLICATION_MANAGER) + # add message file ids + message_file_ids.append(message_file.id) + + return tool_invoke_response, tool_invoke_meta + + def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action: + """ + convert dict to action + """ + return AgentScratchpadUnit.Action( + action_name=action['action'], + action_input=action['action_input'] + ) def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str: """ @@ -482,15 +342,46 @@ def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dic return instruction - def _init_agent_scratchpad(self, - agent_scratchpad: list[AgentScratchpadUnit], - messages: list[PromptMessage] - ) -> list[AgentScratchpadUnit]: + def _init_react_state(self, query) -> None: """ init agent scratchpad """ + self._query = query + self._agent_scratchpad = [] + self._historic_prompt_messages = self._organize_historic_prompt_messages() + + @abstractmethod + def _organize_prompt_messages(self) -> list[PromptMessage]: + """ + organize prompt messages + """ + + def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str: + """ + format assistant message + """ + message = '' + for scratchpad in agent_scratchpad: + if scratchpad.is_final(): + message += f"Final Answer: {scratchpad.agent_response}" + else: + message += f"Thought: {scratchpad.thought}\n\n" + if scratchpad.action_str: + message += f"Action: {scratchpad.action_str}\n\n" + if scratchpad.observation: + message += f"Observation: {scratchpad.observation}\n\n" + + return message + + def _organize_historic_prompt_messages(self) -> list[PromptMessage]: + """ + organize historic prompt messages + """ + result: list[PromptMessage] = [] + scratchpad: list[AgentScratchpadUnit] = [] current_scratchpad: AgentScratchpadUnit = None - for message in messages: + + for message in self.history_prompt_messages: if isinstance(message, AssistantPromptMessage): current_scratchpad = AgentScratchpadUnit( agent_response=message.content, @@ -505,186 +396,29 @@ def _init_agent_scratchpad(self, action_name=message.tool_calls[0].function.name, action_input=json.loads(message.tool_calls[0].function.arguments) ) + current_scratchpad.action_str = json.dumps( + current_scratchpad.action.to_dict() + ) except: pass - - agent_scratchpad.append(current_scratchpad) + + scratchpad.append(current_scratchpad) elif isinstance(message, ToolPromptMessage): if current_scratchpad: 
current_scratchpad.observation = message.content - - return agent_scratchpad + elif isinstance(message, UserPromptMessage): + result.append(message) - def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"], - agent_prompt_message: AgentPromptEntity, - ): - """ - check chain of thought prompt messages, a standard prompt message is like: - Respond to the human as helpfully and accurately as possible. - - {{instruction}} - - You have access to the following tools: + if scratchpad: + result.append(AssistantPromptMessage( + content=self._format_assistant_message(scratchpad) + )) - {{tools}} + scratchpad = [] - Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). - Valid action values: "Final Answer" or {{tool_names}} - - Provide only ONE action per $JSON_BLOB, as shown: - - ``` - { - "action": $TOOL_NAME, - "action_input": $ACTION_INPUT - } - ``` - """ - - # parse agent prompt message - first_prompt = agent_prompt_message.first_prompt - next_iteration = agent_prompt_message.next_iteration - - if not isinstance(first_prompt, str) or not isinstance(next_iteration, str): - raise ValueError("first_prompt or next_iteration is required in CoT agent mode") - - # check instruction, tools, and tool_names slots - if not first_prompt.find("{{instruction}}") >= 0: - raise ValueError("{{instruction}} is required in first_prompt") - if not first_prompt.find("{{tools}}") >= 0: - raise ValueError("{{tools}} is required in first_prompt") - if not first_prompt.find("{{tool_names}}") >= 0: - raise ValueError("{{tool_names}} is required in first_prompt") + if scratchpad: + result.append(AssistantPromptMessage( + content=self._format_assistant_message(scratchpad) + )) - if mode == "completion": - if not first_prompt.find("{{query}}") >= 0: - raise ValueError("{{query}} is required in first_prompt") - if not first_prompt.find("{{agent_scratchpad}}") >= 0: - raise ValueError("{{agent_scratchpad}} is required in first_prompt") - - if mode == "completion": - if not next_iteration.find("{{observation}}") >= 0: - raise ValueError("{{observation}} is required in next_iteration") - - def _convert_scratchpad_list_to_str(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str: - """ - convert agent scratchpad list to str - """ - next_iteration = self.app_config.agent.prompt.next_iteration - - result = '' - for scratchpad in agent_scratchpad: - result += (scratchpad.thought or '') + (scratchpad.action_str or '') + \ - next_iteration.replace("{{observation}}", scratchpad.observation or 'It seems that no response is available') - - return result - - def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"], - prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - agent_scratchpad: list[AgentScratchpadUnit], - agent_prompt_message: AgentPromptEntity, - instruction: str, - input: str, - ) -> list[PromptMessage]: - """ - organize chain of thought prompt messages, a standard prompt message is like: - Respond to the human as helpfully and accurately as possible. - - {{instruction}} - - You have access to the following tools: - - {{tools}} - - Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). 
- Valid action values: "Final Answer" or {{tool_names}} - - Provide only ONE action per $JSON_BLOB, as shown: - - ``` - {{{{ - "action": $TOOL_NAME, - "action_input": $ACTION_INPUT - }}}} - ``` - """ - - self._check_cot_prompt_messages(mode, agent_prompt_message) - - # parse agent prompt message - first_prompt = agent_prompt_message.first_prompt - - # parse tools - tools_str = self._jsonify_tool_prompt_messages(tools) - - # parse tools name - tool_names = '"' + '","'.join([tool.name for tool in tools]) + '"' - - # get system message - system_message = first_prompt.replace("{{instruction}}", instruction) \ - .replace("{{tools}}", tools_str) \ - .replace("{{tool_names}}", tool_names) - - # organize prompt messages - if mode == "chat": - # override system message - overridden = False - prompt_messages = prompt_messages.copy() - for prompt_message in prompt_messages: - if isinstance(prompt_message, SystemPromptMessage): - prompt_message.content = system_message - overridden = True - break - - # convert tool prompt messages to user prompt messages - for idx, prompt_message in enumerate(prompt_messages): - if isinstance(prompt_message, ToolPromptMessage): - prompt_messages[idx] = UserPromptMessage( - content=prompt_message.content - ) - - if not overridden: - prompt_messages.insert(0, SystemPromptMessage( - content=system_message, - )) - - # add assistant message - if len(agent_scratchpad) > 0 and not self._is_first_iteration: - prompt_messages.append(AssistantPromptMessage( - content=(agent_scratchpad[-1].thought or '') + (agent_scratchpad[-1].action_str or ''), - )) - - # add user message - if len(agent_scratchpad) > 0 and not self._is_first_iteration: - prompt_messages.append(UserPromptMessage( - content=(agent_scratchpad[-1].observation or 'It seems that no response is available'), - )) - - self._is_first_iteration = False - - return prompt_messages - elif mode == "completion": - # parse agent scratchpad - agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad) - self._is_first_iteration = False - # parse prompt messages - return [UserPromptMessage( - content=first_prompt.replace("{{instruction}}", instruction) - .replace("{{tools}}", tools_str) - .replace("{{tool_names}}", tool_names) - .replace("{{query}}", input) - .replace("{{agent_scratchpad}}", agent_scratchpad_str), - )] - else: - raise ValueError(f"mode {mode} is not supported") - - def _jsonify_tool_prompt_messages(self, tools: list[PromptMessageTool]) -> str: - """ - jsonify tool prompt messages - """ - tools = jsonable_encoder(tools) - try: - return json.dumps(tools, ensure_ascii=False) - except json.JSONDecodeError: - return json.dumps(tools) + return result \ No newline at end of file diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py new file mode 100644 index 0000000000000..a904f3e64175c --- /dev/null +++ b/api/core/agent/cot_chat_agent_runner.py @@ -0,0 +1,71 @@ +import json + +from core.agent.cot_agent_runner import CotAgentRunner +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.utils.encoders import jsonable_encoder + + +class CotChatAgentRunner(CotAgentRunner): + def _organize_system_prompt(self) -> SystemPromptMessage: + """ + Organize system prompt + """ + prompt_entity = self.app_config.agent.prompt + first_prompt = prompt_entity.first_prompt + + system_prompt = first_prompt \ + .replace("{{instruction}}", self._instruction) \ + 
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \ + .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools])) + + return SystemPromptMessage(content=system_prompt) + + def _organize_prompt_messages(self) -> list[PromptMessage]: + """ + Organize + """ + # organize system prompt + system_message = self._organize_system_prompt() + + # organize historic prompt messages + historic_messages = self._historic_prompt_messages + + # organize current assistant messages + agent_scratchpad = self._agent_scratchpad + if not agent_scratchpad: + assistant_messages = [] + else: + assistant_message = AssistantPromptMessage(content='') + for unit in agent_scratchpad: + if unit.is_final(): + assistant_message.content += f"Final Answer: {unit.agent_response}" + else: + assistant_message.content += f"Thought: {unit.thought}\n\n" + if unit.action_str: + assistant_message.content += f"Action: {unit.action_str}\n\n" + if unit.observation: + assistant_message.content += f"Observation: {unit.observation}\n\n" + + assistant_messages = [assistant_message] + + # query messages + query_messages = UserPromptMessage(content=self._query) + + if assistant_messages: + messages = [ + system_message, + *historic_messages, + query_messages, + *assistant_messages, + UserPromptMessage(content='continue') + ] + else: + messages = [system_message, *historic_messages, query_messages] + + # join all messages + return messages \ No newline at end of file diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py new file mode 100644 index 0000000000000..3f0298d5a3639 --- /dev/null +++ b/api/core/agent/cot_completion_agent_runner.py @@ -0,0 +1,69 @@ +import json + +from core.agent.cot_agent_runner import CotAgentRunner +from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage +from core.model_runtime.utils.encoders import jsonable_encoder + + +class CotCompletionAgentRunner(CotAgentRunner): + def _organize_instruction_prompt(self) -> str: + """ + Organize instruction prompt + """ + prompt_entity = self.app_config.agent.prompt + first_prompt = prompt_entity.first_prompt + + system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \ + .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \ + .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools])) + + return system_prompt + + def _organize_historic_prompt(self) -> str: + """ + Organize historic prompt + """ + historic_prompt_messages = self._historic_prompt_messages + historic_prompt = "" + + for message in historic_prompt_messages: + if isinstance(message, UserPromptMessage): + historic_prompt += f"Question: {message.content}\n\n" + elif isinstance(message, AssistantPromptMessage): + historic_prompt += message.content + "\n\n" + + return historic_prompt + + def _organize_prompt_messages(self) -> list[PromptMessage]: + """ + Organize prompt messages + """ + # organize system prompt + system_prompt = self._organize_instruction_prompt() + + # organize historic prompt messages + historic_prompt = self._organize_historic_prompt() + + # organize current assistant messages + agent_scratchpad = self._agent_scratchpad + assistant_prompt = '' + for unit in agent_scratchpad: + if unit.is_final(): + assistant_prompt += f"Final Answer: {unit.agent_response}" + else: + assistant_prompt += f"Thought: {unit.thought}\n\n" + if unit.action_str: + 
assistant_prompt += f"Action: {unit.action_str}\n\n" + if unit.observation: + assistant_prompt += f"Observation: {unit.observation}\n\n" + + # query messages + query_prompt = f"Question: {self._query}" + + # join all messages + prompt = system_prompt \ + .replace("{{historic_messages}}", historic_prompt) \ + .replace("{{agent_scratchpad}}", assistant_prompt) \ + .replace("{{query}}", query_prompt) + + return [UserPromptMessage(content=prompt)] \ No newline at end of file diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index e7016d6030cc2..5284faa02ecba 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -34,12 +34,29 @@ class Action(BaseModel): action_name: str action_input: Union[dict, str] + def to_dict(self) -> dict: + """ + Convert to dictionary. + """ + return { + 'action': self.action_name, + 'action_input': self.action_input, + } + agent_response: Optional[str] = None thought: Optional[str] = None action_str: Optional[str] = None observation: Optional[str] = None action: Optional[Action] = None + def is_final(self) -> bool: + """ + Check if the scratchpad unit is final. + """ + return self.action is None or ( + 'final' in self.action.action_name.lower() and + 'answer' in self.action.action_name.lower() + ) class AgentEntity(BaseModel): """ diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index e66500d327688..a9b3a80073446 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -1,6 +1,7 @@ import json import logging from collections.abc import Generator +from copy import deepcopy from typing import Any, Union from core.agent.base_agent_runner import BaseAgentRunner @@ -10,21 +11,21 @@ from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, - PromptMessageTool, + PromptMessageContentType, SystemPromptMessage, + TextPromptMessageContent, ToolPromptMessage, UserPromptMessage, ) from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.tool_engine import ToolEngine -from models.model import Conversation, Message, MessageAgentThought +from models.model import Message logger = logging.getLogger(__name__) class FunctionCallAgentRunner(BaseAgentRunner): - def run(self, conversation: Conversation, - message: Message, - query: str, + def run(self, + message: Message, query: str, **kwargs: Any ) -> Generator[LLMResultChunk, None, None]: """ Run FunctionCall agent application @@ -35,40 +36,17 @@ def run(self, conversation: Conversation, prompt_template = app_config.prompt_template.simple_prompt_template or '' prompt_messages = self.history_prompt_messages - prompt_messages = self.organize_prompt_messages( - prompt_template=prompt_template, - query=query, - prompt_messages=prompt_messages - ) + prompt_messages = self._init_system_message(prompt_template, prompt_messages) + prompt_messages = self._organize_user_query(query, prompt_messages) # convert tools into ModelRuntime Tool format - prompt_messages_tools: list[PromptMessageTool] = [] - tool_instances = {} - for tool in app_config.agent.tools if app_config.agent else []: - try: - prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) - except Exception: - # api tool may be deleted - continue - # save tool entity - tool_instances[tool.tool_name] = tool_entity - # save prompt tool - prompt_messages_tools.append(prompt_tool) - - # convert dataset tools into ModelRuntime Tool format - for dataset_tool in self.dataset_tools: - prompt_tool = 
self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool) - # save prompt tool - prompt_messages_tools.append(prompt_tool) - # save tool entity - tool_instances[dataset_tool.identity.name] = dataset_tool + tool_instances, prompt_messages_tools = self._init_prompt_tools() iteration_step = 1 max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 # continue to run until there is not any tool call function_call_state = True - agent_thoughts: list[MessageAgentThought] = [] llm_usage = { 'usage': None } @@ -207,19 +185,25 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): ) ) + assistant_message = AssistantPromptMessage( + content='', + tool_calls=[] + ) if tool_calls: - prompt_messages.append(AssistantPromptMessage( - content='', - name='', - tool_calls=[AssistantPromptMessage.ToolCall( + assistant_message.tool_calls=[ + AssistantPromptMessage.ToolCall( id=tool_call[0], type='function', function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False) ) - ) for tool_call in tool_calls] - )) + ) for tool_call in tool_calls + ] + else: + assistant_message.content = response + + prompt_messages.append(assistant_message) # save thought self.save_agent_thought( @@ -239,12 +223,6 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): final_answer += response + '\n' - # update prompt messages - if response.strip(): - prompt_messages.append(AssistantPromptMessage( - content=response, - )) - # call tools tool_responses = [] for tool_call_id, tool_call_name, tool_call_args in tool_calls: @@ -287,9 +265,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): } tool_responses.append(tool_response) - prompt_messages = self.organize_prompt_messages( - prompt_template=prompt_template, - query=None, + prompt_messages = self._organize_assistant_message( tool_call_id=tool_call_id, tool_call_name=tool_call_name, tool_response=tool_response['tool_response'], @@ -324,6 +300,8 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): iteration_step += 1 + prompt_messages = self._clear_user_prompt_image_messages(prompt_messages) + self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult( @@ -386,29 +364,68 @@ def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list return tool_calls - def organize_prompt_messages(self, prompt_template: str, - query: str = None, - tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None, - prompt_messages: list[PromptMessage] = None - ) -> list[PromptMessage]: + def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: """ - Organize prompt messages + Initialize system message """ - - if not prompt_messages: - prompt_messages = [ + if not prompt_messages and prompt_template: + return [ SystemPromptMessage(content=prompt_template), - UserPromptMessage(content=query), ] + + if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template: + prompt_messages.insert(0, SystemPromptMessage(content=prompt_template)) + + return prompt_messages + + def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: + """ + Organize user query + """ + if self.files: + prompt_message_contents = 
[TextPromptMessageContent(data=query)] + for file_obj in self.files: + prompt_message_contents.append(file_obj.prompt_message_content) + + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: - if tool_response: - prompt_messages = prompt_messages.copy() - prompt_messages.append( - ToolPromptMessage( - content=tool_response, - tool_call_id=tool_call_id, - name=tool_call_name, - ) + prompt_messages.append(UserPromptMessage(content=query)) + + return prompt_messages + + def _organize_assistant_message(self, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None, + prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: + """ + Organize assistant message + """ + prompt_messages = deepcopy(prompt_messages) + + if tool_response is not None: + prompt_messages.append( + ToolPromptMessage( + content=tool_response, + tool_call_id=tool_call_id, + name=tool_call_name, ) + ) + + return prompt_messages + + def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: + """ + As for now, gpt supports both fc and vision at the first iteration. + We need to remove the image messages from the prompt messages at the first iteration. + """ + prompt_messages = deepcopy(prompt_messages) + + for prompt_message in prompt_messages: + if isinstance(prompt_message, UserPromptMessage): + if isinstance(prompt_message.content, list): + prompt_message.content = '\n'.join([ + content.data if content.type == PromptMessageContentType.TEXT else + '[image]' if content.type == PromptMessageContentType.IMAGE else + '[file]' + for content in prompt_message.content + ]) return prompt_messages \ No newline at end of file diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py new file mode 100644 index 0000000000000..91ac41143ba1e --- /dev/null +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -0,0 +1,183 @@ +import json +import re +from collections.abc import Generator +from typing import Union + +from core.agent.entities import AgentScratchpadUnit +from core.model_runtime.entities.llm_entities import LLMResultChunk + + +class CotAgentOutputParser: + @classmethod + def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None]) -> \ + Generator[Union[str, AgentScratchpadUnit.Action], None, None]: + def parse_action(json_str): + try: + action = json.loads(json_str) + action_name = None + action_input = None + + for key, value in action.items(): + if 'input' in key.lower(): + action_input = value + else: + action_name = value + + if action_name is not None and action_input is not None: + return AgentScratchpadUnit.Action( + action_name=action_name, + action_input=action_input, + ) + else: + return json_str or '' + except: + return json_str or '' + + def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]: + code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL) + if not code_blocks: + return + for block in code_blocks: + json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE) + yield parse_action(json_text) + + code_block_cache = '' + code_block_delimiter_count = 0 + in_code_block = False + json_cache = '' + json_quote_count = 0 + in_json = False + got_json = False + + action_cache = '' + action_str = 'action:' + action_idx = 0 + + thought_cache = '' + thought_str = 'thought:' + thought_idx = 0 + + for response in llm_response: + response = response.delta.message.content + 
if not isinstance(response, str): + continue + + # stream + index = 0 + while index < len(response): + steps = 1 + delta = response[index:index+steps] + last_character = response[index-1] if index > 0 else '' + + if delta == '`': + code_block_cache += delta + code_block_delimiter_count += 1 + else: + if not in_code_block: + if code_block_delimiter_count > 0: + yield code_block_cache + code_block_cache = '' + else: + code_block_cache += delta + code_block_delimiter_count = 0 + + if not in_code_block and not in_json: + if delta.lower() == action_str[action_idx] and action_idx == 0: + if last_character not in ['\n', ' ', '']: + index += steps + yield delta + continue + + action_cache += delta + action_idx += 1 + if action_idx == len(action_str): + action_cache = '' + action_idx = 0 + index += steps + continue + elif delta.lower() == action_str[action_idx] and action_idx > 0: + action_cache += delta + action_idx += 1 + if action_idx == len(action_str): + action_cache = '' + action_idx = 0 + index += steps + continue + else: + if action_cache: + yield action_cache + action_cache = '' + action_idx = 0 + + if delta.lower() == thought_str[thought_idx] and thought_idx == 0: + if last_character not in ['\n', ' ', '']: + index += steps + yield delta + continue + + thought_cache += delta + thought_idx += 1 + if thought_idx == len(thought_str): + thought_cache = '' + thought_idx = 0 + index += steps + continue + elif delta.lower() == thought_str[thought_idx] and thought_idx > 0: + thought_cache += delta + thought_idx += 1 + if thought_idx == len(thought_str): + thought_cache = '' + thought_idx = 0 + index += steps + continue + else: + if thought_cache: + yield thought_cache + thought_cache = '' + thought_idx = 0 + + if code_block_delimiter_count == 3: + if in_code_block: + yield from extra_json_from_code_block(code_block_cache) + code_block_cache = '' + + in_code_block = not in_code_block + code_block_delimiter_count = 0 + + if not in_code_block: + # handle single json + if delta == '{': + json_quote_count += 1 + in_json = True + json_cache += delta + elif delta == '}': + json_cache += delta + if json_quote_count > 0: + json_quote_count -= 1 + if json_quote_count == 0: + in_json = False + got_json = True + index += steps + continue + else: + if in_json: + json_cache += delta + + if got_json: + got_json = False + yield parse_action(json_cache) + json_cache = '' + json_quote_count = 0 + in_json = False + + if not in_code_block and not in_json: + yield delta.replace('`', '') + + index += steps + + if code_block_cache: + yield code_block_cache + + if json_cache: + yield parse_action(json_cache) + diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 37e10f4bcfc4b..e5cf585f8283d 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -1,4 +1,5 @@ import logging +import os import threading import uuid from collections.abc import Generator @@ -189,6 +190,8 @@ def _generate_worker(self, flask_app: Flask, logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: + if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: logger.exception("Unknown Error when generating") diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py 
b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 44db0d4a3396e..9866db12f6f77 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -98,6 +98,7 @@ def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity, ) self._stream_generate_routes = self._get_stream_generate_routes() + self._conversation_name_generate_thread = None def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: """ @@ -108,6 +109,12 @@ def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStrea db.session.refresh(self._user) db.session.close() + # start generate conversation name thread + self._conversation_name_generate_thread = self._generate_conversation_name( + self._conversation, + self._application_generate_entity.query + ) + generator = self._process_stream_response() if self._stream: return self._to_stream_response(generator) @@ -278,6 +285,9 @@ def _process_stream_response(self) -> Generator[StreamResponse, None, None]: else: continue + if self._conversation_name_generate_thread: + self._conversation_name_generate_thread.join() + def _save_message(self) -> None: """ Save message. diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 632cf4f80ad25..847d31440976c 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -1,4 +1,5 @@ import logging +import os import threading import uuid from collections.abc import Generator @@ -198,6 +199,8 @@ def _generate_worker(self, flask_app: Flask, logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: + if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: logger.exception("Unknown Error when generating") diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 0dc8a1e2184ab..dfa5d4591b39d 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -1,7 +1,8 @@ import logging from typing import cast -from core.agent.cot_agent_runner import CotAgentRunner +from core.agent.cot_chat_agent_runner import CotChatAgentRunner +from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner from core.agent.entities import AgentEntity from core.agent.fc_agent_runner import FunctionCallAgentRunner from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig @@ -11,8 +12,8 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.entities.model_entities import ModelFeature +from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage +from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.moderation.base import ModerationException from core.tools.entities.tool_entities import ToolRuntimeVariablePool @@ -207,48 +208,40 @@ def run(self, application_generate_entity: 
AgentChatAppGenerateEntity, # start agent runner if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: - assistant_cot_runner = CotAgentRunner( - tenant_id=app_config.tenant_id, - application_generate_entity=application_generate_entity, - app_config=app_config, - model_config=application_generate_entity.model_config, - config=agent_entity, - queue_manager=queue_manager, - message=message, - user_id=application_generate_entity.user_id, - memory=memory, - prompt_messages=prompt_message, - variables_pool=tool_variables, - db_variables=tool_conversation_variables, - model_instance=model_instance - ) - invoke_result = assistant_cot_runner.run( - conversation=conversation, - message=message, - query=query, - inputs=inputs, - ) + # check LLM mode + if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: + runner_cls = CotChatAgentRunner + elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION.value: + runner_cls = CotCompletionAgentRunner + else: + raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}") elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING: - assistant_fc_runner = FunctionCallAgentRunner( - tenant_id=app_config.tenant_id, - application_generate_entity=application_generate_entity, - app_config=app_config, - model_config=application_generate_entity.model_config, - config=agent_entity, - queue_manager=queue_manager, - message=message, - user_id=application_generate_entity.user_id, - memory=memory, - prompt_messages=prompt_message, - variables_pool=tool_variables, - db_variables=tool_conversation_variables, - model_instance=model_instance - ) - invoke_result = assistant_fc_runner.run( - conversation=conversation, - message=message, - query=query, - ) + runner_cls = FunctionCallAgentRunner + else: + raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}") + + runner = runner_cls( + tenant_id=app_config.tenant_id, + application_generate_entity=application_generate_entity, + conversation=conversation, + app_config=app_config, + model_config=application_generate_entity.model_config, + config=agent_entity, + queue_manager=queue_manager, + message=message, + user_id=application_generate_entity.user_id, + memory=memory, + prompt_messages=prompt_message, + variables_pool=tool_variables, + db_variables=tool_conversation_variables, + model_instance=model_instance + ) + + invoke_result = runner.run( + message=message, + query=query, + inputs=inputs, + ) # handle invoke result self._handle_invoke_result( diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 6bf309ca1b50b..e67901cca85cd 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -1,4 +1,5 @@ import logging +import os import threading import uuid from collections.abc import Generator @@ -195,6 +196,8 @@ def _generate_worker(self, flask_app: Flask, logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: + if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: logger.exception("Unknown Error when generating") diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index d51f3db5409ee..ba2095076fa40 100644 --- 
a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -156,6 +156,8 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, dataset_retrieval = DatasetRetrieval() context = dataset_retrieval.retrieve( + app_id=app_record.id, + user_id=application_generate_entity.user_id, tenant_id=app_record.tenant_id, model_config=application_generate_entity.model_config, config=app_config.dataset, diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 0e83da3dfdf5a..5f93afcad71a9 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -1,4 +1,5 @@ import logging +import os import threading import uuid from collections.abc import Generator @@ -184,6 +185,8 @@ def _generate_worker(self, flask_app: Flask, logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: + if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: logger.exception("Unknown Error when generating") diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 649d73d96180f..40102f89996ba 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -116,6 +116,8 @@ def run(self, application_generate_entity: CompletionAppGenerateEntity, dataset_retrieval = DatasetRetrieval() context = dataset_retrieval.retrieve( + app_id=app_record.id, + user_id=application_generate_entity.user_id, tenant_id=app_record.tenant_id, model_config=application_generate_entity.model_config, config=dataset_config, diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 759790e7ee262..a9b038ab51d87 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -1,4 +1,5 @@ import logging +import os import threading import uuid from collections.abc import Generator @@ -137,6 +138,8 @@ def _generate_worker(self, flask_app: Flask, logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: + if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: logger.exception("Unknown Error when generating") diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 4fc9d6abaa6ef..a7dbb4754c38d 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -97,6 +97,8 @@ def __init__(self, application_generate_entity: Union[ ) ) + self._conversation_name_generate_thread = None + def process(self) -> Union[ ChatbotAppBlockingResponse, CompletionAppBlockingResponse, @@ -110,6 +112,13 @@ def process(self) -> Union[ db.session.refresh(self._message) db.session.close() + if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: + # start generate conversation name thread + self._conversation_name_generate_thread = self._generate_conversation_name( + 
self._conversation, + self._application_generate_entity.query + ) + generator = self._process_stream_response() if self._stream: return self._to_stream_response(generator) @@ -256,6 +265,9 @@ def _process_stream_response(self) -> Generator[StreamResponse, None, None]: else: continue + if self._conversation_name_generate_thread: + self._conversation_name_generate_thread.join() + def _save_message(self) -> None: """ Save message. diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index 16eb3d4fc28f2..2848455278a5c 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -1,5 +1,8 @@ +from threading import Thread from typing import Optional, Union +from flask import Flask, current_app + from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, AgentChatAppGenerateEntity, @@ -19,9 +22,10 @@ MessageReplaceStreamResponse, MessageStreamResponse, ) +from core.llm_generator.llm_generator import LLMGenerator from core.tools.tool_file_manager import ToolFileManager from extensions.ext_database import db -from models.model import MessageAnnotation, MessageFile +from models.model import AppMode, Conversation, MessageAnnotation, MessageFile from services.annotation_service import AppAnnotationService @@ -34,6 +38,59 @@ class MessageCycleManage: ] _task_state: Union[EasyUITaskState, AdvancedChatTaskState] + def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]: + """ + Generate conversation name. + :param conversation: conversation + :param query: query + :return: thread + """ + is_first_message = self._application_generate_entity.conversation_id is None + extras = self._application_generate_entity.extras + auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True) + + if auto_generate_conversation_name and is_first_message: + # start generate thread + thread = Thread(target=self._generate_conversation_name_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'conversation_id': conversation.id, + 'query': query + }) + + thread.start() + + return thread + + return None + + def _generate_conversation_name_worker(self, + flask_app: Flask, + conversation_id: str, + query: str): + with flask_app.app_context(): + # get conversation and message + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == conversation_id) + .first() + ) + + if conversation.mode != AppMode.COMPLETION.value: + app_model = conversation.app + if not app_model: + return + + # generate conversation name + try: + name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query) + conversation.name = name + except: + pass + + db.session.merge(conversation) + db.session.commit() + db.session.close() + def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]: """ Handle annotation reply. 
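The `_generate_conversation_name` helper added above captures a pattern worth calling out: the LLM call that names the conversation runs on a background thread that re-enters the Flask application context, the main thread keeps streaming the answer, and the pipeline joins the thread once streaming is done. Below is a minimal sketch of that pattern under stated assumptions: it uses a plain Flask app, and `naming_worker` / `stream_response` are illustrative stand-ins, not Dify APIs (the real code passes `current_app._get_current_object()` into the thread; the sketch uses a module-level app for brevity).

import threading

from flask import Flask

app = Flask(__name__)


def naming_worker(flask_app: Flask, conversation_id: str, query: str) -> None:
    # Re-enter the app context on the worker thread; without it, anything that
    # depends on Flask extensions (e.g. a SQLAlchemy session) would fail here.
    with flask_app.app_context():
        name = query[:30] or "New conversation"  # stand-in for the LLM naming call
        print(f"conversation {conversation_id} renamed to {name!r}")


def stream_response(conversation_id: str, query: str):
    # Start the naming worker early so it overlaps with answer streaming.
    worker = threading.Thread(
        target=naming_worker,
        kwargs={"flask_app": app, "conversation_id": conversation_id, "query": query},
    )
    worker.start()

    for chunk in ("Hello", ", ", "world!"):  # stand-in for the LLM result stream
        yield chunk

    # Join only after the stream is drained, mirroring the join the task
    # pipelines perform at the end of _process_stream_response.
    worker.join()

Draining the generator, e.g. list(stream_response("c-1", "How do I deploy?")), blocks until the worker has finished, so the rename is committed before the request completes.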
diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 49c79a5c052f8..48ff34fef9ede 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -1,6 +1,6 @@ import json import time -from datetime import datetime +from datetime import datetime, timezone from typing import Any, Optional, Union, cast from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity @@ -120,7 +120,7 @@ def _workflow_run_success(self, workflow_run: WorkflowRun, workflow_run.elapsed_time = time.perf_counter() - start_at workflow_run.total_tokens = total_tokens workflow_run.total_steps = total_steps - workflow_run.finished_at = datetime.utcnow() + workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() db.session.refresh(workflow_run) @@ -149,7 +149,7 @@ def _workflow_run_failed(self, workflow_run: WorkflowRun, workflow_run.elapsed_time = time.perf_counter() - start_at workflow_run.total_tokens = total_tokens workflow_run.total_steps = total_steps - workflow_run.finished_at = datetime.utcnow() + workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() db.session.refresh(workflow_run) @@ -223,7 +223,7 @@ def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNode workflow_node_execution.outputs = json.dumps(outputs) if outputs else None workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \ if execution_metadata else None - workflow_node_execution.finished_at = datetime.utcnow() + workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() db.session.refresh(workflow_node_execution) @@ -251,7 +251,7 @@ def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeE workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.error = error workflow_node_execution.elapsed_time = time.perf_counter() - start_at - workflow_node_execution.finished_at = datetime.utcnow() + workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_node_execution.inputs = json.dumps(inputs) if inputs else None workflow_node_execution.process_data = json.dumps(process_data) if process_data else None workflow_node_execution.outputs = json.dumps(outputs) if outputs else None diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index 2e524466a1138..af9db39890eeb 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -1,12 +1,32 @@ import os -from typing import Any, Optional, Union +from typing import Any, Optional, TextIO, Union -from langchain.callbacks.base import BaseCallbackHandler -from langchain.input import print_text from pydantic import BaseModel +_TEXT_COLOR_MAPPING = { + "blue": "36;1", + "yellow": "33;1", + "pink": "38;5;200", + "green": "32;1", + "red": "31;1", +} -class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel): +def get_colored_text(text: str, color: str) -> str: + """Get colored text.""" + color_str = _TEXT_COLOR_MAPPING[color] + return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" + + +def print_text( + text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None +) -> None: + 
"""Print text with highlighting and no end characters.""" + text_to_print = get_colored_text(text, color) if color else text + print(text_to_print, end=end, file=file) + if file: + file.flush() # ensure all printed content are written to file + +class DifyAgentCallbackHandler(BaseModel): """Callback Handler that prints to std out.""" color: Optional[str] = '' current_loop = 1 diff --git a/api/core/docstore/dataset_docstore.py b/api/core/docstore/dataset_docstore.py index 9a051fd4cb68d..7567493b9f91e 100644 --- a/api/core/docstore/dataset_docstore.py +++ b/api/core/docstore/dataset_docstore.py @@ -84,7 +84,7 @@ def add_documents( if not isinstance(doc, Document): raise ValueError("doc must be a Document") - segment_document = self.get_document(doc_id=doc.metadata['doc_id'], raise_error=False) + segment_document = self.get_document_segment(doc_id=doc.metadata['doc_id']) # NOTE: doc could already exist in the store, but we overwrite it if not allow_update and segment_document: diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py index 4156368e562c2..b7e0cc0c2b2ae 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/embedding/cached_embedding.py @@ -41,7 +41,8 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: embedding_queue_embeddings = [] try: model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance) - model_schema = model_type_instance.get_model_schema(self._model_instance.model, self._model_instance.credentials) + model_schema = model_type_instance.get_model_schema(self._model_instance.model, + self._model_instance.credentials) max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \ if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1 for i in range(0, len(embedding_queue_texts), max_chunks): @@ -61,17 +62,20 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: except Exception as e: logging.exception('Failed transform embedding: ', e) cache_embeddings = [] - for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings): - text_embeddings[i] = embedding - hash = helper.generate_text_hash(texts[i]) - if hash not in cache_embeddings: - embedding_cache = Embedding(model_name=self._model_instance.model, - hash=hash, - provider_name=self._model_instance.provider) - embedding_cache.set_embedding(embedding) - db.session.add(embedding_cache) - cache_embeddings.append(hash) - db.session.commit() + try: + for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings): + text_embeddings[i] = embedding + hash = helper.generate_text_hash(texts[i]) + if hash not in cache_embeddings: + embedding_cache = Embedding(model_name=self._model_instance.model, + hash=hash, + provider_name=self._model_instance.provider) + embedding_cache.set_embedding(embedding) + db.session.add(embedding_cache) + cache_embeddings.append(hash) + db.session.commit() + except IntegrityError: + db.session.rollback() except Exception as ex: db.session.rollback() logger.error('Failed to embed documents: ', ex) diff --git a/api/core/entities/message_entities.py b/api/core/entities/message_entities.py index 6f767aafc7159..b9406e24c4a98 100644 --- a/api/core/entities/message_entities.py +++ b/api/core/entities/message_entities.py @@ -1,19 +1,8 @@ import enum -from typing import Any, cast +from typing import Any -from langchain.schema import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage from pydantic import BaseModel 
-from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessage, - SystemPromptMessage, - TextPromptMessageContent, - ToolPromptMessage, - UserPromptMessage, -) - class PromptMessageFileType(enum.Enum): IMAGE = 'image' @@ -38,98 +27,3 @@ class DETAIL(enum.Enum): type: PromptMessageFileType = PromptMessageFileType.IMAGE detail: DETAIL = DETAIL.LOW - - -class LCHumanMessageWithFiles(HumanMessage): - # content: Union[str, list[Union[str, Dict]]] - content: str - files: list[PromptMessageFile] - - -def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMessage]: - prompt_messages = [] - for message in messages: - if isinstance(message, HumanMessage): - if isinstance(message, LCHumanMessageWithFiles): - file_prompt_message_contents = [] - for file in message.files: - if file.type == PromptMessageFileType.IMAGE: - file = cast(ImagePromptMessageFile, file) - file_prompt_message_contents.append(ImagePromptMessageContent( - data=file.data, - detail=ImagePromptMessageContent.DETAIL.HIGH - if file.detail.value == "high" else ImagePromptMessageContent.DETAIL.LOW - )) - - prompt_message_contents = [TextPromptMessageContent(data=message.content)] - prompt_message_contents.extend(file_prompt_message_contents) - - prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) - else: - prompt_messages.append(UserPromptMessage(content=message.content)) - elif isinstance(message, AIMessage): - message_kwargs = { - 'content': message.content - } - - if 'function_call' in message.additional_kwargs: - message_kwargs['tool_calls'] = [ - AssistantPromptMessage.ToolCall( - id=message.additional_kwargs['function_call']['id'], - type='function', - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=message.additional_kwargs['function_call']['name'], - arguments=message.additional_kwargs['function_call']['arguments'] - ) - ) - ] - - prompt_messages.append(AssistantPromptMessage(**message_kwargs)) - elif isinstance(message, SystemMessage): - prompt_messages.append(SystemPromptMessage(content=message.content)) - elif isinstance(message, FunctionMessage): - prompt_messages.append(ToolPromptMessage(content=message.content, tool_call_id=message.name)) - - return prompt_messages - - -def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list[BaseMessage]: - messages = [] - for prompt_message in prompt_messages: - if isinstance(prompt_message, UserPromptMessage): - if isinstance(prompt_message.content, str): - messages.append(HumanMessage(content=prompt_message.content)) - else: - message_contents = [] - for content in prompt_message.content: - if isinstance(content, TextPromptMessageContent): - message_contents.append(content.data) - elif isinstance(content, ImagePromptMessageContent): - message_contents.append({ - 'type': 'image', - 'data': content.data, - 'detail': content.detail.value - }) - - messages.append(HumanMessage(content=message_contents)) - elif isinstance(prompt_message, AssistantPromptMessage): - message_kwargs = { - 'content': prompt_message.content - } - - if prompt_message.tool_calls: - message_kwargs['additional_kwargs'] = { - 'function_call': { - 'id': prompt_message.tool_calls[0].id, - 'name': prompt_message.tool_calls[0].function.name, - 'arguments': prompt_message.tool_calls[0].function.arguments - } - } - - messages.append(AIMessage(**message_kwargs)) - elif isinstance(prompt_message, SystemPromptMessage): - 
messages.append(SystemMessage(content=prompt_message.content)) - elif isinstance(prompt_message, ToolPromptMessage): - messages.append(FunctionMessage(name=prompt_message.tool_call_id, content=prompt_message.content)) - - return messages diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index b83ae0c8e766f..303034693deb1 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -203,7 +203,7 @@ def add_or_update_custom_credentials(self, credentials: dict) -> None: if provider_record: provider_record.encrypted_config = json.dumps(credentials) provider_record.is_valid = True - provider_record.updated_at = datetime.datetime.utcnow() + provider_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() else: provider_record = Provider( @@ -351,7 +351,7 @@ def add_or_update_custom_model_credentials(self, model_type: ModelType, model: s if provider_model_record: provider_model_record.encrypted_config = json.dumps(credentials) provider_model_record.is_valid = True - provider_model_record.updated_at = datetime.datetime.utcnow() + provider_model_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() else: provider_model_record = ProviderModel( diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 9a9bc1a75c48c..3221bbe59e9f8 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -1,17 +1,17 @@ -from os import environ from typing import Literal, Optional from httpx import post from pydantic import BaseModel from yarl import URL +from config import get_env from core.helper.code_executor.javascript_transformer import NodeJsTemplateTransformer from core.helper.code_executor.jina2_transformer import Jinja2TemplateTransformer from core.helper.code_executor.python_transformer import PythonTemplateTransformer # Code Executor -CODE_EXECUTION_ENDPOINT = environ.get('CODE_EXECUTION_ENDPOINT', '') -CODE_EXECUTION_API_KEY = environ.get('CODE_EXECUTION_API_KEY', '') +CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT') +CODE_EXECUTION_API_KEY = get_env('CODE_EXECUTION_API_KEY') CODE_EXECUTION_TIMEOUT= (10, 60) @@ -27,6 +27,7 @@ class Data(BaseModel): message: str data: Data + class CodeExecutor: @classmethod def execute_code(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict: diff --git a/api/core/helper/code_executor/javascript_transformer.py b/api/core/helper/code_executor/javascript_transformer.py index 62b7a66468fad..29b8e06e86d4e 100644 --- a/api/core/helper/code_executor/javascript_transformer.py +++ b/api/core/helper/code_executor/javascript_transformer.py @@ -29,16 +29,16 @@ def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]: :param inputs: inputs :return: """ - + # transform inputs to json string - inputs_str = json.dumps(inputs, indent=4) + inputs_str = json.dumps(inputs, indent=4, ensure_ascii=False) # replace code and inputs runner = NODEJS_RUNNER.replace('{{code}}', code) runner = runner.replace('{{inputs}}', inputs_str) return runner, NODEJS_PRELOAD - + @classmethod def transform_response(cls, response: str) -> dict: """ diff --git a/api/core/helper/code_executor/jina2_transformer.py b/api/core/helper/code_executor/jina2_transformer.py index 047d851423e5c..d7b46b0e25f56 100644 --- 
a/api/core/helper/code_executor/jina2_transformer.py +++ b/api/core/helper/code_executor/jina2_transformer.py @@ -62,10 +62,10 @@ def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]: # transform jinja2 template to python code runner = PYTHON_RUNNER.replace('{{code}}', code) - runner = runner.replace('{{inputs}}', json.dumps(inputs, indent=4)) + runner = runner.replace('{{inputs}}', json.dumps(inputs, indent=4, ensure_ascii=False)) return runner, JINJA2_PRELOAD - + @classmethod def transform_response(cls, response: str) -> dict: """ @@ -81,4 +81,4 @@ def transform_response(cls, response: str) -> dict: return { 'result': result - } \ No newline at end of file + } diff --git a/api/core/helper/code_executor/python_transformer.py b/api/core/helper/code_executor/python_transformer.py index 35edffdd79f91..ca758c1efa1d1 100644 --- a/api/core/helper/code_executor/python_transformer.py +++ b/api/core/helper/code_executor/python_transformer.py @@ -34,7 +34,7 @@ def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]: """ # transform inputs to json string - inputs_str = json.dumps(inputs, indent=4) + inputs_str = json.dumps(inputs, indent=4, ensure_ascii=False) # replace code and inputs runner = PYTHON_RUNNER.replace('{{code}}', code) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 07975d25545b5..51e77393d660e 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -19,6 +19,7 @@ from core.model_runtime.entities.model_entities import ModelType, PriceType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -80,7 +81,7 @@ def run(self, dataset_documents: list[DatasetDocument]): except ProviderTokenNotInitError as e: dataset_document.indexing_status = 'error' dataset_document.error = str(e.description) - dataset_document.stopped_at = datetime.datetime.utcnow() + dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() except ObjectDeletedError: logging.warning('Document deleted, document id: {}'.format(dataset_document.id)) @@ -88,7 +89,7 @@ def run(self, dataset_documents: list[DatasetDocument]): logging.exception("consume document failed") dataset_document.indexing_status = 'error' dataset_document.error = str(e) - dataset_document.stopped_at = datetime.datetime.utcnow() + dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() def run_in_splitting_status(self, dataset_document: DatasetDocument): @@ -139,13 +140,13 @@ def run_in_splitting_status(self, dataset_document: DatasetDocument): except ProviderTokenNotInitError as e: dataset_document.indexing_status = 'error' dataset_document.error = str(e.description) - dataset_document.stopped_at = datetime.datetime.utcnow() + dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() except Exception as e: logging.exception("consume document failed") dataset_document.indexing_status = 'error' dataset_document.error = str(e) - dataset_document.stopped_at = 
datetime.datetime.utcnow() + dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() def run_in_indexing_status(self, dataset_document: DatasetDocument): @@ -201,13 +202,13 @@ def run_in_indexing_status(self, dataset_document: DatasetDocument): except ProviderTokenNotInitError as e: dataset_document.indexing_status = 'error' dataset_document.error = str(e.description) - dataset_document.stopped_at = datetime.datetime.utcnow() + dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() except Exception as e: logging.exception("consume document failed") dataset_document.indexing_status = 'error' dataset_document.error = str(e) - dataset_document.stopped_at = datetime.datetime.utcnow() + dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict, @@ -381,7 +382,7 @@ def _extract(self, index_processor: BaseIndexProcessor, dataset_document: Datase after_indexing_status="splitting", extra_update_params={ DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]), - DatasetDocument.parsing_completed_at: datetime.datetime.utcnow() + DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) } ) @@ -466,7 +467,7 @@ def _step_split(self, text_docs: list[Document], splitter: TextSplitter, doc_store.add_documents(documents) # update document status to indexing - cur_time = datetime.datetime.utcnow() + cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) self._update_document_index_status( document_id=dataset_document.id, after_indexing_status="indexing", @@ -481,7 +482,7 @@ def _step_split(self, text_docs: list[Document], splitter: TextSplitter, dataset_document_id=dataset_document.id, update_params={ DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.utcnow() + DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) } ) @@ -657,18 +658,25 @@ def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset, if embedding_model_instance: embedding_model_type_instance = embedding_model_instance.model_type_instance embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance) - with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: - futures = [] - for i in range(0, len(documents), chunk_size): - chunk_documents = documents[i:i + chunk_size] - futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor, - chunk_documents, dataset, - dataset_document, embedding_model_instance, - embedding_model_type_instance)) - - for future in futures: - tokens += future.result() - + # create keyword index + create_keyword_thread = threading.Thread(target=self._process_keyword_index, + args=(current_app._get_current_object(), + dataset.id, dataset_document.id, documents)) + create_keyword_thread.start() + if dataset.indexing_technique == 'high_quality': + with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: + futures = [] + for i in range(0, len(documents), chunk_size): + chunk_documents = documents[i:i + chunk_size] + futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor, + chunk_documents, dataset, + 
dataset_document, embedding_model_instance, + embedding_model_type_instance)) + + for future in futures: + tokens += future.result() + + create_keyword_thread.join() indexing_end_at = time.perf_counter() # update document status to completed @@ -677,11 +685,32 @@ def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset, after_indexing_status="completed", extra_update_params={ DatasetDocument.tokens: tokens, - DatasetDocument.completed_at: datetime.datetime.utcnow(), + DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at, } ) + def _process_keyword_index(self, flask_app, dataset_id, document_id, documents): + with flask_app.app_context(): + dataset = Dataset.query.filter_by(id=dataset_id).first() + if not dataset: + raise ValueError("no dataset found") + keyword = Keyword(dataset) + keyword.create(documents) + if dataset.indexing_technique != 'high_quality': + document_ids = [document.metadata['doc_id'] for document in documents] + db.session.query(DocumentSegment).filter( + DocumentSegment.document_id == document_id, + DocumentSegment.index_node_id.in_(document_ids), + DocumentSegment.status == "indexing" + ).update({ + DocumentSegment.status: "completed", + DocumentSegment.enabled: True, + DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + }) + + db.session.commit() + def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document, embedding_model_instance, embedding_model_type_instance): with flask_app.app_context(): @@ -700,7 +729,7 @@ def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, d ) # load index - index_processor.load(dataset, chunk_documents) + index_processor.load(dataset, chunk_documents, with_keywords=False) document_ids = [document.metadata['doc_id'] for document in chunk_documents] db.session.query(DocumentSegment).filter( @@ -710,7 +739,7 @@ def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, d ).update({ DocumentSegment.status: "completed", DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.utcnow() + DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) }) db.session.commit() @@ -809,7 +838,7 @@ def _load_segments(self, dataset, dataset_document, documents): doc_store.add_documents(documents) # update document status to indexing - cur_time = datetime.datetime.utcnow() + cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) self._update_document_index_status( document_id=dataset_document.id, after_indexing_status="indexing", @@ -824,7 +853,7 @@ def _load_segments(self, dataset, dataset_document, documents): dataset_document_id=dataset_document.id, update_params={ DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.utcnow() + DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) } ) pass diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 2fc60daab4529..14de8649c637e 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -1,8 +1,7 @@ import json import logging -from langchain.schema import OutputParserException - +from core.llm_generator.output_parser.errors import OutputParserException from core.llm_generator.output_parser.rule_config_generator import 
RuleConfigGeneratorOutputParser from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser from core.llm_generator.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT diff --git a/api/core/llm_generator/output_parser/errors.py b/api/core/llm_generator/output_parser/errors.py new file mode 100644 index 0000000000000..6a60f8de80372 --- /dev/null +++ b/api/core/llm_generator/output_parser/errors.py @@ -0,0 +1,2 @@ +class OutputParserException(Exception): + pass diff --git a/api/core/llm_generator/output_parser/rule_config_generator.py b/api/core/llm_generator/output_parser/rule_config_generator.py index b95653f69c6ee..f6d4bcf11ad64 100644 --- a/api/core/llm_generator/output_parser/rule_config_generator.py +++ b/api/core/llm_generator/output_parser/rule_config_generator.py @@ -1,12 +1,11 @@ from typing import Any -from langchain.schema import BaseOutputParser, OutputParserException - +from core.llm_generator.output_parser.errors import OutputParserException from core.llm_generator.prompts import RULE_CONFIG_GENERATE_TEMPLATE from libs.json_in_md_parser import parse_and_check_json_markdown -class RuleConfigGeneratorOutputParser(BaseOutputParser): +class RuleConfigGeneratorOutputParser: def get_format_instructions(self) -> str: return RULE_CONFIG_GENERATE_TEMPLATE diff --git a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py index ad30bcfa079b3..3f046c68fceaf 100644 --- a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py +++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py @@ -2,12 +2,10 @@ import re from typing import Any -from langchain.schema import BaseOutputParser - from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT -class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser): +class SuggestedQuestionsAfterAnswerOutputParser: def get_format_instructions(self) -> str: return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 252b5f1cbad49..cd0b2508d407f 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -3,6 +3,7 @@ from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, + ImagePromptMessageContent, PromptMessage, PromptMessageRole, TextPromptMessageContent, @@ -124,7 +125,17 @@ def get_history_prompt_text(self, human_prefix: str = "Human", else: continue - message = f"{role}: {m.content}" - string_messages.append(message) + if isinstance(m.content, list): + inner_msg = "" + for content in m.content: + if isinstance(content, TextPromptMessageContent): + inner_msg += f"{content.data}\n" + elif isinstance(content, ImagePromptMessageContent): + inner_msg += "[image]\n" + + string_messages.append(f"{role}: {inner_msg.strip()}") + else: + message = f"{role}: {m.content}" + string_messages.append(message) return "\n".join(string_messages) \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml index c06f122984939..62cae69dbede4 100644 --- a/api/core/model_runtime/model_providers/_position.yaml +++ b/api/core/model_runtime/model_providers/_position.yaml @@ -1,3 +1,5 @@ +- modelhub +- openai_api_compatible - openai - anthropic - 
azure_openai @@ -26,4 +28,3 @@ - yi - openllm - localai -- openai_api_compatible diff --git a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml index 792d051d94901..828698acc7bd8 100644 --- a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml +++ b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml @@ -99,6 +99,12 @@ model_credential_schema: show_on: - variable: __model_type value: llm + - label: + en_US: gpt-4-turbo-2024-04-09 + value: gpt-4-turbo-2024-04-09 + show_on: + - variable: __model_type + value: llm - label: en_US: gpt-4-0125-preview value: gpt-4-0125-preview diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index 4b89adaa49beb..eb6d985f230b8 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -343,8 +343,12 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == '') and \ - delta.delta.function_call is None: + # Handling exceptions when content filters' streaming mode is set to asynchronous modified filter + if delta.delta is None or ( + delta.finish_reason is None + and (delta.delta.content is None or delta.delta.content == '') + and delta.delta.function_call is None + ): continue # assistant_message_tool_calls = delta.delta.tool_calls diff --git a/api/core/model_runtime/model_providers/bedrock/bedrock.yaml b/api/core/model_runtime/model_providers/bedrock/bedrock.yaml index 35374c69baaa6..ea628839822bc 100644 --- a/api/core/model_runtime/model_providers/bedrock/bedrock.yaml +++ b/api/core/model_runtime/model_providers/bedrock/bedrock.yaml @@ -15,6 +15,7 @@ help: en_US: https://console.aws.amazon.com/ supported_model_types: - llm + - text-embedding configurate_methods: - predefined-model provider_credential_schema: @@ -74,7 +75,7 @@ provider_credential_schema: label: en_US: Available Model Name zh_Hans: 可用模型名称 - type: secret-input + type: text-input placeholder: en_US: A model you have access to (e.g. amazon.titan-text-lite-v1) for validation. 
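A quick illustration of the streaming guard added to the Azure OpenAI handler above. The Chunk and Delta classes below are stand-ins for the SDK's real objects, not the actual types; the point is that chunks with a missing delta (as emitted when the content filter runs in asynchronous "modified filter" mode), or with neither text, a function call, nor a finish reason, are skipped.

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Delta:
        content: Optional[str] = None
        function_call: Optional[dict] = None

    @dataclass
    class Chunk:
        delta: Optional[Delta] = None
        finish_reason: Optional[str] = None

    def should_skip(chunk: Chunk) -> bool:
        # mirror of the condition added above: drop placeholder/empty chunks
        return chunk.delta is None or (
            chunk.finish_reason is None
            and not chunk.delta.content
            and chunk.delta.function_call is None
        )

    assert should_skip(Chunk(delta=None))                     # async content-filter placeholder
    assert should_skip(Chunk(delta=Delta()))                  # empty delta, not finished
    assert not should_skip(Chunk(delta=Delta(content="Hi")))  # real content flows through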
zh_Hans: 为了进行验证,请输入一个您可用的模型名称 (例如:amazon.titan-text-lite-v1) diff --git a/api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-express-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-express-v1.yaml index 64f992b9133fc..543c16d5cd53e 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-express-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-express-v1.yaml @@ -2,8 +2,6 @@ model: amazon.titan-text-express-v1 label: en_US: Titan Text G1 - Express model_type: llm -features: - - agent-thought model_properties: mode: chat context_size: 8192 diff --git a/api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-lite-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-lite-v1.yaml index 69b298b5711e1..2c6151c239048 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-lite-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/amazon.titan-text-lite-v1.yaml @@ -2,8 +2,6 @@ model: amazon.titan-text-lite-v1 label: en_US: Titan Text G1 - Lite model_type: llm -features: - - agent-thought model_properties: mode: chat context_size: 4096 diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v1.yaml index cb2271d4017e4..6a714b1055b2a 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v1.yaml @@ -50,3 +50,4 @@ pricing: output: '0.024' unit: '0.001' currency: USD +deprecated: true diff --git a/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-light-text-v14.yaml b/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-light-text-v14.yaml index 1fad91005857d..74500095511f1 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-light-text-v14.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-light-text-v14.yaml @@ -22,7 +22,7 @@ parameter_rules: min: 0 max: 500 default: 0 - - name: max_tokens_to_sample + - name: max_tokens use_template: max_tokens required: true default: 4096 diff --git a/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-text-v14.yaml b/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-text-v14.yaml index ed775afd7a448..6aea5be170c6e 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-text-v14.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/cohere.command-text-v14.yaml @@ -8,9 +8,9 @@ model_properties: parameter_rules: - name: temperature use_template: temperature - - name: top_p + - name: p use_template: top_p - - name: top_k + - name: k label: zh_Hans: 取样数量 en_US: Top k @@ -19,7 +19,7 @@ parameter_rules: zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 en_US: Only sample from the top K options for each subsequent token. 
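The parameter renames above (p, k, max_tokens) match the names Bedrock's Cohere Command models accept directly in the request body. A minimal sketch of what such an InvokeModel call could look like; the region, prompt, and parameter values are illustrative, and the body shape is an assumption based on those names rather than the exact payload the runtime builds:

    import json

    import boto3

    runtime = boto3.client("bedrock-runtime", region_name="us-east-1")  # illustrative region

    body = {
        "prompt": "Summarize: Bedrock now exposes Cohere Command models.",
        "temperature": 0.7,
        "p": 0.75,           # formerly surfaced as top_p
        "k": 0,              # formerly surfaced as top_k
        "max_tokens": 1024,  # replaces max_tokens_to_sample
    }

    response = runtime.invoke_model(
        body=json.dumps(body),
        modelId="cohere.command-text-v14",
        accept="application/json",
        contentType="application/json",
    )
    print(json.loads(response["body"].read()))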
required: false - - name: max_tokens_to_sample + - name: max_tokens use_template: max_tokens required: true default: 4096 diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index 0e256999c0ee1..0b0959eaa08f2 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -402,25 +402,25 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :param credentials: model credentials :return: """ - - if "anthropic.claude-3" in model: - try: - self._invoke_claude(model=model, - credentials=credentials, - prompt_messages=[{"role": "user", "content": "ping"}], - model_parameters={}, - stop=None, - stream=False) - - except Exception as ex: - raise CredentialsValidateFailedError(str(ex)) - + required_params = {} + if "anthropic" in model: + required_params = { + "max_tokens": 32, + } + elif "ai21" in model: + # ValidationException: Malformed input request: #/temperature: expected type: Number, found: Null#/maxTokens: expected type: Integer, found: Null#/topP: expected type: Number, found: Null, please reformat your input and try again. + required_params = { + "temperature": 0.7, + "topP": 0.9, + "maxTokens": 32, + } + try: ping_message = UserPromptMessage(content="ping") - self._generate(model=model, + self._invoke(model=model, credentials=credentials, prompt_messages=[ping_message], - model_parameters={}, + model_parameters=required_params, stream=False) except ClientError as ex: @@ -503,7 +503,7 @@ def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage if model_prefix == "amazon": payload["textGenerationConfig"] = { **model_parameters } - payload["textGenerationConfig"]["stopSequences"] = ["User:"] + (stop if stop else []) + payload["textGenerationConfig"]["stopSequences"] = ["User:"] payload["inputText"] = self._convert_messages_to_prompt(prompt_messages, model_prefix) @@ -513,10 +513,6 @@ def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage payload["maxTokens"] = model_parameters.get("maxTokens") payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix) - # jurassic models only support a single stop sequence - if stop: - payload["stopSequences"] = stop[0] - if model_parameters.get("presencePenalty"): payload["presencePenalty"] = {model_parameters.get("presencePenalty")} if model_parameters.get("frequencyPenalty"): diff --git a/api/core/rag/retrieval/agent/__init__.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/__init__.py similarity index 100% rename from api/core/rag/retrieval/agent/__init__.py rename to api/core/model_runtime/model_providers/bedrock/text_embedding/__init__.py diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/_position.yaml b/api/core/model_runtime/model_providers/bedrock/text_embedding/_position.yaml new file mode 100644 index 0000000000000..5419ff530b9b8 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/_position.yaml @@ -0,0 +1,3 @@ +- amazon.titan-embed-text-v1 +- cohere.embed-english-v3 +- cohere.embed-multilingual-v3 diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/amazon.titan-embed-text-v1.yaml b/api/core/model_runtime/model_providers/bedrock/text_embedding/amazon.titan-embed-text-v1.yaml new file mode 100644 index 0000000000000..6a1cf75be1a8c --- /dev/null +++ 
b/api/core/model_runtime/model_providers/bedrock/text_embedding/amazon.titan-embed-text-v1.yaml @@ -0,0 +1,8 @@ +model: amazon.titan-embed-text-v1 +model_type: text-embedding +model_properties: + context_size: 8192 +pricing: + input: '0.0001' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/cohere.embed-english-v3.yaml b/api/core/model_runtime/model_providers/bedrock/text_embedding/cohere.embed-english-v3.yaml new file mode 100644 index 0000000000000..d49aa2a99c3a9 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/cohere.embed-english-v3.yaml @@ -0,0 +1,8 @@ +model: cohere.embed-english-v3 +model_type: text-embedding +model_properties: + context_size: 512 +pricing: + input: '0.1' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/cohere.embed-multilingual-v3.yaml b/api/core/model_runtime/model_providers/bedrock/text_embedding/cohere.embed-multilingual-v3.yaml new file mode 100644 index 0000000000000..63bab59d2cea6 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/cohere.embed-multilingual-v3.yaml @@ -0,0 +1,8 @@ +model: cohere.embed-multilingual-v3 +model_type: text-embedding +model_properties: + context_size: 512 +pricing: + input: '0.1' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py new file mode 100644 index 0000000000000..69436cd737afb --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py @@ -0,0 +1,234 @@ +import json +import logging +import time +from typing import Optional + +import boto3 +from botocore.config import Config +from botocore.exceptions import ( + ClientError, + EndpointConnectionError, + NoRegionError, + ServiceNotInRegionError, + UnknownServiceError, +) + +from core.model_runtime.entities.model_entities import PriceType +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel + +logger = logging.getLogger(__name__) + +class BedrockTextEmbeddingModel(TextEmbeddingModel): + + + def _invoke(self, model: str, credentials: dict, + texts: list[str], user: Optional[str] = None) \ + -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ + client_config = Config( + region_name=credentials["aws_region"] + ) + + bedrock_runtime = boto3.client( + service_name='bedrock-runtime', + config=client_config, + aws_access_key_id=credentials["aws_access_key_id"], + aws_secret_access_key=credentials["aws_secret_access_key"] + ) + + embeddings = [] + token_usage = 0 + + model_prefix = model.split('.')[0] + + if model_prefix == "amazon" : + for text in texts: + body = { + "inputText": text, + } + response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) + embeddings.extend([response_body.get('embedding')]) + token_usage += response_body.get('inputTextTokenCount') + logger.warning(f'Total 
Tokens: {token_usage}') + result = TextEmbeddingResult( + model=model, + embeddings=embeddings, + usage=self._calc_response_usage( + model=model, + credentials=credentials, + tokens=token_usage + ) + ) + return result + + if model_prefix == "cohere" : + input_type = 'search_document' if len(texts) > 1 else 'search_query' + for text in texts: + body = { + "texts": [text], + "input_type": input_type, + } + response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) + embeddings.extend(response_body.get('embeddings')) + token_usage += len(text) + result = TextEmbeddingResult( + model=model, + embeddings=embeddings, + usage=self._calc_response_usage( + model=model, + credentials=credentials, + tokens=token_usage + ) + ) + return result + + #others + raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") + + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + num_tokens = 0 + for text in texts: + num_tokens += self._get_num_tokens_by_gpt2(text) + return num_tokens + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller + The value is the md = genai.GenerativeModel(model)error type thrown by the model, + which needs to be converted into a unified error type for the caller. 
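For orientation, a condensed sketch of the two request/response shapes the new Bedrock embedding model handles, using the same body keys as the amazon and cohere branches above; the region and model ids are illustrative:

    import json

    import boto3

    runtime = boto3.client("bedrock-runtime", region_name="us-east-1")  # illustrative region

    def embed(model_id: str, body: dict) -> dict:
        response = runtime.invoke_model(
            body=json.dumps(body),
            modelId=model_id,
            accept="application/json",
            contentType="application/json",
        )
        return json.loads(response["body"].read())

    # Amazon Titan embeds one text per call and reports the token count itself.
    titan = embed("amazon.titan-embed-text-v1", {"inputText": "hello world"})
    vector, used_tokens = titan["embedding"], titan["inputTextTokenCount"]

    # Cohere embed takes a list of texts plus an input_type hint.
    cohere_resp = embed("cohere.embed-english-v3",
                        {"texts": ["hello world"], "input_type": "search_query"})
    vectors = cohere_resp["embeddings"]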
+ + :return: Invoke emd = genai.GenerativeModel(model)rror mapping + """ + return { + InvokeConnectionError: [], + InvokeServerUnavailableError: [], + InvokeRateLimitError: [], + InvokeAuthorizationError: [], + InvokeBadRequestError: [] + } + + def _create_payload(self, model_prefix: str, texts: list[str], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True): + """ + Create payload for bedrock api call depending on model provider + """ + payload = dict() + + if model_prefix == "amazon": + payload['inputText'] = texts + + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, + credentials=credentials, + price_type=PriceType.INPUT, + tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at + ) + + return usage + + def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]: + """ + Map client error to invoke error + + :param error_code: error code + :param error_msg: error message + :return: invoke error + """ + + if error_code == "AccessDeniedException": + return InvokeAuthorizationError(error_msg) + elif error_code in ["ResourceNotFoundException", "ValidationException"]: + return InvokeBadRequestError(error_msg) + elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]: + return InvokeRateLimitError(error_msg) + elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]: + return InvokeServerUnavailableError(error_msg) + elif error_code == "ModelStreamErrorException": + return InvokeConnectionError(error_msg) + + return InvokeError(error_msg) + + + def _invoke_bedrock_embedding(self, model: str, bedrock_runtime, body: dict, ): + accept = 'application/json' + content_type = 'application/json' + try: + response = bedrock_runtime.invoke_model( + body=json.dumps(body), + modelId=model, + accept=accept, + contentType=content_type + ) + response_body = json.loads(response.get('body').read().decode('utf-8')) + return response_body + except ClientError as ex: + error_code = ex.response['Error']['Code'] + full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" + raise self._map_client_to_invoke_error(error_code, full_error_msg) + + except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: + raise InvokeConnectionError(str(ex)) + + except UnknownServiceError as ex: + raise InvokeServerUnavailableError(str(ex)) + + except Exception as ex: + raise InvokeError(str(ex)) diff --git a/api/core/model_runtime/model_providers/cohere/llm/_position.yaml b/api/core/model_runtime/model_providers/cohere/llm/_position.yaml index 367117c9e81f5..42d06f49a2b39 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/cohere/llm/_position.yaml @@ -1,3 +1,5 @@ +- command-r +- command-r-plus - command-chat - command-light-chat - command-nightly-chat diff --git a/api/core/model_runtime/model_providers/cohere/llm/command-chat.yaml 
b/api/core/model_runtime/model_providers/cohere/llm/command-chat.yaml index 4bcfae6e5d08a..5f233f35ceeeb 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/command-chat.yaml +++ b/api/core/model_runtime/model_providers/cohere/llm/command-chat.yaml @@ -31,7 +31,7 @@ parameter_rules: max: 500 - name: max_tokens use_template: max_tokens - default: 256 + default: 1024 max: 4096 - name: preamble_override label: diff --git a/api/core/model_runtime/model_providers/cohere/llm/command-light-chat.yaml b/api/core/model_runtime/model_providers/cohere/llm/command-light-chat.yaml index 8d8075967c775..b5f00487703a0 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/command-light-chat.yaml +++ b/api/core/model_runtime/model_providers/cohere/llm/command-light-chat.yaml @@ -31,7 +31,7 @@ parameter_rules: max: 500 - name: max_tokens use_template: max_tokens - default: 256 + default: 1024 max: 4096 - name: preamble_override label: diff --git a/api/core/model_runtime/model_providers/cohere/llm/command-light-nightly-chat.yaml b/api/core/model_runtime/model_providers/cohere/llm/command-light-nightly-chat.yaml index 4b6b66951ee59..1c96b24030224 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/command-light-nightly-chat.yaml +++ b/api/core/model_runtime/model_providers/cohere/llm/command-light-nightly-chat.yaml @@ -31,7 +31,7 @@ parameter_rules: max: 500 - name: max_tokens use_template: max_tokens - default: 256 + default: 1024 max: 4096 - name: preamble_override label: diff --git a/api/core/model_runtime/model_providers/cohere/llm/command-light-nightly.yaml b/api/core/model_runtime/model_providers/cohere/llm/command-light-nightly.yaml index 6a76c25019bc0..4616f76689786 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/command-light-nightly.yaml +++ b/api/core/model_runtime/model_providers/cohere/llm/command-light-nightly.yaml @@ -35,7 +35,7 @@ parameter_rules: use_template: frequency_penalty - name: max_tokens use_template: max_tokens - default: 256 + default: 1024 max: 4096 pricing: input: '0.3' diff --git a/api/core/model_runtime/model_providers/cohere/llm/command-light.yaml b/api/core/model_runtime/model_providers/cohere/llm/command-light.yaml index ff9a594b66fbb..161756b3220d2 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/command-light.yaml +++ b/api/core/model_runtime/model_providers/cohere/llm/command-light.yaml @@ -35,7 +35,7 @@ parameter_rules: use_template: frequency_penalty - name: max_tokens use_template: max_tokens - default: 256 + default: 1024 max: 4096 pricing: input: '0.3' diff --git a/api/core/model_runtime/model_providers/cohere/llm/command-nightly-chat.yaml b/api/core/model_runtime/model_providers/cohere/llm/command-nightly-chat.yaml index 811f237c887f4..739e09e72e9e6 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/command-nightly-chat.yaml +++ b/api/core/model_runtime/model_providers/cohere/llm/command-nightly-chat.yaml @@ -31,7 +31,7 @@ parameter_rules: max: 500 - name: max_tokens use_template: max_tokens - default: 256 + default: 1024 max: 4096 - name: preamble_override label: diff --git a/api/core/model_runtime/model_providers/cohere/llm/command-nightly.yaml b/api/core/model_runtime/model_providers/cohere/llm/command-nightly.yaml index 2c99bf7684255..1e025e40c4b45 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/command-nightly.yaml +++ b/api/core/model_runtime/model_providers/cohere/llm/command-nightly.yaml @@ -35,7 +35,7 @@ parameter_rules: use_template: frequency_penalty - name: max_tokens 
use_template: max_tokens - default: 256 + default: 1024 max: 4096 pricing: input: '1.0' diff --git a/api/core/model_runtime/model_providers/cohere/llm/command-r-plus.yaml b/api/core/model_runtime/model_providers/cohere/llm/command-r-plus.yaml new file mode 100644 index 0000000000000..617e6853ea651 --- /dev/null +++ b/api/core/model_runtime/model_providers/cohere/llm/command-r-plus.yaml @@ -0,0 +1,45 @@ +model: command-r-plus +label: + en_US: command-r-plus +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + max: 5.0 + - name: p + use_template: top_p + default: 0.75 + min: 0.01 + max: 0.99 + - name: k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + default: 0 + min: 0 + max: 500 + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 1024 + max: 4096 +pricing: + input: '3' + output: '15' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/cohere/llm/command-r.yaml b/api/core/model_runtime/model_providers/cohere/llm/command-r.yaml new file mode 100644 index 0000000000000..c36680443b7a4 --- /dev/null +++ b/api/core/model_runtime/model_providers/cohere/llm/command-r.yaml @@ -0,0 +1,45 @@ +model: command-r +label: + en_US: command-r +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + max: 5.0 + - name: p + use_template: top_p + default: 0.75 + min: 0.01 + max: 0.99 + - name: k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
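The Cohere integration below is rewritten against the v5 Python SDK, whose typed classes (ChatMessage, Tool, StreamedChatResponse_*) replace the old cohere.responses module. A minimal usage sketch built only from the names the new code imports; the API key, messages, model name, and tool are placeholders:

    import cohere
    from cohere import (
        ChatMessage,
        StreamedChatResponse_StreamEnd,
        StreamedChatResponse_TextGeneration,
        Tool,
        ToolParameterDefinitionsValue,
    )

    client = cohere.Client("YOUR_API_KEY")  # placeholder credential

    weather_tool = Tool(
        name="get_weather",
        description="Look up the current weather for a city.",
        parameter_definitions={
            "city": ToolParameterDefinitionsValue(description="City name", type="str", required=True),
        },
    )

    stream = client.chat_stream(
        message="What's the weather in Berlin?",
        chat_history=[ChatMessage(role="USER", message="Hi"),
                      ChatMessage(role="CHATBOT", message="Hello! How can I help?")],
        model="command-r",
        tools=[weather_tool],
    )

    for chunk in stream:
        if isinstance(chunk, StreamedChatResponse_TextGeneration):
            print(chunk.text, end="")
        elif isinstance(chunk, StreamedChatResponse_StreamEnd):
            print("\n[finish_reason]", chunk.finish_reason)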
+ required: false + default: 0 + min: 0 + max: 500 + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 1024 + max: 4096 +pricing: + input: '0.5' + output: '1.5' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/cohere/llm/command.yaml b/api/core/model_runtime/model_providers/cohere/llm/command.yaml index d41c2951fcd6b..0cac7c35ea140 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/command.yaml +++ b/api/core/model_runtime/model_providers/cohere/llm/command.yaml @@ -35,7 +35,7 @@ parameter_rules: use_template: frequency_penalty - name: max_tokens use_template: max_tokens - default: 256 + default: 1024 max: 4096 pricing: input: '1.0' diff --git a/api/core/model_runtime/model_providers/cohere/llm/llm.py b/api/core/model_runtime/model_providers/cohere/llm/llm.py index 50805bce85b60..6ace77b813c90 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/llm.py +++ b/api/core/model_runtime/model_providers/cohere/llm/llm.py @@ -1,20 +1,38 @@ +import json import logging -from collections.abc import Generator +from collections.abc import Generator, Iterator from typing import Optional, Union, cast import cohere -from cohere.responses import Chat, Generations -from cohere.responses.chat import StreamEnd, StreamingChat, StreamTextGeneration -from cohere.responses.generation import StreamingGenerations, StreamingText +from cohere import ( + ChatMessage, + ChatStreamRequestToolResultsItem, + GenerateStreamedResponse, + GenerateStreamedResponse_StreamEnd, + GenerateStreamedResponse_StreamError, + GenerateStreamedResponse_TextGeneration, + Generation, + NonStreamedChatResponse, + StreamedChatResponse, + StreamedChatResponse_StreamEnd, + StreamedChatResponse_TextGeneration, + StreamedChatResponse_ToolCallsGeneration, + Tool, + ToolCall, + ToolParameterDefinitionsValue, +) +from cohere.core import RequestOptions from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageContentType, + PromptMessageRole, PromptMessageTool, SystemPromptMessage, TextPromptMessageContent, + ToolPromptMessage, UserPromptMessage, ) from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType @@ -64,6 +82,7 @@ def _invoke(self, model: str, credentials: dict, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters, + tools=tools, stop=stop, stream=stream, user=user @@ -159,19 +178,26 @@ def _generate(self, model: str, credentials: dict, if stop: model_parameters['end_sequences'] = stop - response = client.generate( - prompt=prompt_messages[0].content, - model=model, - stream=stream, - **model_parameters, - ) - if stream: + response = client.generate_stream( + prompt=prompt_messages[0].content, + model=model, + **model_parameters, + request_options=RequestOptions(max_retries=0) + ) + return self._handle_generate_stream_response(model, credentials, response, prompt_messages) + else: + response = client.generate( + prompt=prompt_messages[0].content, + model=model, + **model_parameters, + request_options=RequestOptions(max_retries=0) + ) - return self._handle_generate_response(model, credentials, response, prompt_messages) + return self._handle_generate_response(model, credentials, response, prompt_messages) - def 
_handle_generate_response(self, model: str, credentials: dict, response: Generations, + def _handle_generate_response(self, model: str, credentials: dict, response: Generation, prompt_messages: list[PromptMessage]) \ -> LLMResult: """ @@ -191,8 +217,8 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Gen ) # calculate num tokens - prompt_tokens = response.meta['billed_units']['input_tokens'] - completion_tokens = response.meta['billed_units']['output_tokens'] + prompt_tokens = int(response.meta.billed_units.input_tokens) + completion_tokens = int(response.meta.billed_units.output_tokens) # transform usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) @@ -207,7 +233,7 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Gen return response - def _handle_generate_stream_response(self, model: str, credentials: dict, response: StreamingGenerations, + def _handle_generate_stream_response(self, model: str, credentials: dict, response: Iterator[GenerateStreamedResponse], prompt_messages: list[PromptMessage]) -> Generator: """ Handle llm stream response @@ -220,8 +246,8 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon index = 1 full_assistant_content = '' for chunk in response: - if isinstance(chunk, StreamingText): - chunk = cast(StreamingText, chunk) + if isinstance(chunk, GenerateStreamedResponse_TextGeneration): + chunk = cast(GenerateStreamedResponse_TextGeneration, chunk) text = chunk.text if text is None: @@ -244,10 +270,16 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon ) index += 1 - elif chunk is None: + elif isinstance(chunk, GenerateStreamedResponse_StreamEnd): + chunk = cast(GenerateStreamedResponse_StreamEnd, chunk) + # calculate num tokens - prompt_tokens = response.meta['billed_units']['input_tokens'] - completion_tokens = response.meta['billed_units']['output_tokens'] + prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages) + completion_tokens = self._num_tokens_from_messages( + model, + credentials, + [AssistantPromptMessage(content=full_assistant_content)] + ) # transform usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) @@ -258,14 +290,18 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon delta=LLMResultChunkDelta( index=index, message=AssistantPromptMessage(content=''), - finish_reason=response.finish_reason, + finish_reason=chunk.finish_reason, usage=usage ) ) break + elif isinstance(chunk, GenerateStreamedResponse_StreamError): + chunk = cast(GenerateStreamedResponse_StreamError, chunk) + raise InvokeBadRequestError(chunk.err) def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke llm chat model @@ -274,6 +310,7 @@ def _chat_generate(self, model: str, credentials: dict, :param credentials: credentials :param prompt_messages: prompt messages :param model_parameters: model parameters + :param tools: tools for tool calling :param stop: stop words :param stream: is stream response :param user: unique user id @@ -282,32 +319,49 @@ def _chat_generate(self, model: str, 
credentials: dict, # initialize client client = cohere.Client(credentials.get('api_key')) - if user: - model_parameters['user_name'] = user + if stop: + model_parameters['stop_sequences'] = stop + + if tools: + if len(tools) == 1: + raise ValueError("Cohere tool call requires at least two tools to be specified.") - message, chat_histories = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages) + model_parameters['tools'] = self._convert_tools(tools) + + message, chat_histories, tool_results \ + = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages) + + if tool_results: + model_parameters['tool_results'] = tool_results # chat model real_model = model if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL: real_model = model.removesuffix('-chat') - response = client.chat( - message=message, - chat_history=chat_histories, - model=real_model, - stream=stream, - return_preamble=True, - **model_parameters, - ) - if stream: - return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, stop) + response = client.chat_stream( + message=message, + chat_history=chat_histories, + model=real_model, + **model_parameters, + request_options=RequestOptions(max_retries=0) + ) + + return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages) + else: + response = client.chat( + message=message, + chat_history=chat_histories, + model=real_model, + **model_parameters, + request_options=RequestOptions(max_retries=0) + ) - return self._handle_chat_generate_response(model, credentials, response, prompt_messages, stop) + return self._handle_chat_generate_response(model, credentials, response, prompt_messages) - def _handle_chat_generate_response(self, model: str, credentials: dict, response: Chat, - prompt_messages: list[PromptMessage], stop: Optional[list[str]] = None) \ + def _handle_chat_generate_response(self, model: str, credentials: dict, response: NonStreamedChatResponse, + prompt_messages: list[PromptMessage]) \ -> LLMResult: """ Handle llm chat response @@ -316,14 +370,27 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, response :param credentials: credentials :param response: response :param prompt_messages: prompt messages - :param stop: stop words :return: llm response """ assistant_text = response.text + tool_calls = [] + if response.tool_calls: + for cohere_tool_call in response.tool_calls: + tool_call = AssistantPromptMessage.ToolCall( + id=cohere_tool_call.name, + type='function', + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=cohere_tool_call.name, + arguments=json.dumps(cohere_tool_call.parameters) + ) + ) + tool_calls.append(tool_call) + # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=assistant_text + content=assistant_text, + tool_calls=tool_calls ) # calculate num tokens @@ -333,44 +400,38 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, response # transform usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - if stop: - # enforce stop tokens - assistant_text = self.enforce_stop_tokens(assistant_text, stop) - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text - ) - # transform response response = LLMResult( model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, - usage=usage, - system_fingerprint=response.preamble + 
usage=usage ) return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: StreamingChat, - prompt_messages: list[PromptMessage], - stop: Optional[list[str]] = None) -> Generator: + def _handle_chat_generate_stream_response(self, model: str, credentials: dict, + response: Iterator[StreamedChatResponse], + prompt_messages: list[PromptMessage]) -> Generator: """ Handle llm chat stream response :param model: model name :param response: response :param prompt_messages: prompt messages - :param stop: stop words :return: llm response chunk generator """ - def final_response(full_text: str, index: int, finish_reason: Optional[str] = None, - preamble: Optional[str] = None) -> LLMResultChunk: + def final_response(full_text: str, + tool_calls: list[AssistantPromptMessage.ToolCall], + index: int, + finish_reason: Optional[str] = None) -> LLMResultChunk: # calculate num tokens prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages) full_assistant_prompt_message = AssistantPromptMessage( - content=full_text + content=full_text, + tool_calls=tool_calls ) completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message]) @@ -380,10 +441,9 @@ def final_response(full_text: str, index: int, finish_reason: Optional[str] = No return LLMResultChunk( model=model, prompt_messages=prompt_messages, - system_fingerprint=preamble, delta=LLMResultChunkDelta( index=index, - message=AssistantPromptMessage(content=''), + message=AssistantPromptMessage(content='', tool_calls=tool_calls), finish_reason=finish_reason, usage=usage ) @@ -391,9 +451,10 @@ def final_response(full_text: str, index: int, finish_reason: Optional[str] = No index = 1 full_assistant_content = '' + tool_calls = [] for chunk in response: - if isinstance(chunk, StreamTextGeneration): - chunk = cast(StreamTextGeneration, chunk) + if isinstance(chunk, StreamedChatResponse_TextGeneration): + chunk = cast(StreamedChatResponse_TextGeneration, chunk) text = chunk.text if text is None: @@ -404,12 +465,6 @@ def final_response(full_text: str, index: int, finish_reason: Optional[str] = No content=text ) - # stop - # notice: This logic can only cover few stop scenarios - if stop and text in stop: - yield final_response(full_assistant_content, index, 'stop') - break - full_assistant_content += text yield LLMResultChunk( @@ -422,39 +477,96 @@ def final_response(full_text: str, index: int, finish_reason: Optional[str] = No ) index += 1 - elif isinstance(chunk, StreamEnd): - chunk = cast(StreamEnd, chunk) - yield final_response(full_assistant_content, index, chunk.finish_reason, response.preamble) + elif isinstance(chunk, StreamedChatResponse_ToolCallsGeneration): + chunk = cast(StreamedChatResponse_ToolCallsGeneration, chunk) + if chunk.tool_calls: + for cohere_tool_call in chunk.tool_calls: + tool_call = AssistantPromptMessage.ToolCall( + id=cohere_tool_call.name, + type='function', + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=cohere_tool_call.name, + arguments=json.dumps(cohere_tool_call.parameters) + ) + ) + tool_calls.append(tool_call) + elif isinstance(chunk, StreamedChatResponse_StreamEnd): + chunk = cast(StreamedChatResponse_StreamEnd, chunk) + yield final_response(full_assistant_content, tool_calls, index, chunk.finish_reason) index += 1 def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \ - -> tuple[str, list[dict]]: + -> tuple[str, list[ChatMessage], 
list[ChatStreamRequestToolResultsItem]]: """ Convert prompt messages to message and chat histories :param prompt_messages: prompt messages :return: """ chat_histories = [] + latest_tool_call_n_outputs = [] for prompt_message in prompt_messages: - chat_histories.append(self._convert_prompt_message_to_dict(prompt_message)) + if prompt_message.role == PromptMessageRole.ASSISTANT: + prompt_message = cast(AssistantPromptMessage, prompt_message) + if prompt_message.tool_calls: + for tool_call in prompt_message.tool_calls: + latest_tool_call_n_outputs.append(ChatStreamRequestToolResultsItem( + call=ToolCall( + name=tool_call.function.name, + parameters=json.loads(tool_call.function.arguments) + ), + outputs=[] + )) + else: + cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message) + if cohere_prompt_message: + chat_histories.append(cohere_prompt_message) + elif prompt_message.role == PromptMessageRole.TOOL: + prompt_message = cast(ToolPromptMessage, prompt_message) + if latest_tool_call_n_outputs: + i = 0 + for tool_call_n_outputs in latest_tool_call_n_outputs: + if tool_call_n_outputs.call.name == prompt_message.tool_call_id: + latest_tool_call_n_outputs[i] = ChatStreamRequestToolResultsItem( + call=ToolCall( + name=tool_call_n_outputs.call.name, + parameters=tool_call_n_outputs.call.parameters + ), + outputs=[{ + "result": prompt_message.content + }] + ) + break + i += 1 + else: + cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message) + if cohere_prompt_message: + chat_histories.append(cohere_prompt_message) + + if latest_tool_call_n_outputs: + new_latest_tool_call_n_outputs = [] + for tool_call_n_outputs in latest_tool_call_n_outputs: + if tool_call_n_outputs.outputs: + new_latest_tool_call_n_outputs.append(tool_call_n_outputs) + + latest_tool_call_n_outputs = new_latest_tool_call_n_outputs # get latest message from chat histories and pop it if len(chat_histories) > 0: latest_message = chat_histories.pop() - message = latest_message['message'] + message = latest_message.message else: raise ValueError('Prompt messages is empty') - return message, chat_histories + return message, chat_histories, latest_tool_call_n_outputs - def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: + def _convert_prompt_message_to_dict(self, message: PromptMessage) -> Optional[ChatMessage]: """ Convert PromptMessage to dict for Cohere model """ if isinstance(message, UserPromptMessage): message = cast(UserPromptMessage, message) if isinstance(message.content, str): - message_dict = {"role": "USER", "message": message.content} + chat_message = ChatMessage(role="USER", message=message.content) else: sub_message_text = '' for message_content in message.content: @@ -462,20 +574,57 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: message_content = cast(TextPromptMessageContent, message_content) sub_message_text += message_content.data - message_dict = {"role": "USER", "message": sub_message_text} + chat_message = ChatMessage(role="USER", message=sub_message_text) elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) - message_dict = {"role": "CHATBOT", "message": message.content} + if not message.content: + return None + chat_message = ChatMessage(role="CHATBOT", message=message.content) elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - message_dict = {"role": "USER", "message": message.content} + chat_message = ChatMessage(role="USER", 
message=message.content) + elif isinstance(message, ToolPromptMessage): + return None else: raise ValueError(f"Got unknown type {message}") - if message.name: - message_dict["user_name"] = message.name + return chat_message + + def _convert_tools(self, tools: list[PromptMessageTool]) -> list[Tool]: + """ + Convert tools to Cohere model + """ + cohere_tools = [] + for tool in tools: + properties = tool.parameters['properties'] + required_properties = tool.parameters['required'] + + parameter_definitions = {} + for p_key, p_val in properties.items(): + required = False + if p_key in required_properties: + required = True + + desc = p_val['description'] + if 'enum' in p_val: + desc += (f"; Only accepts one of the following predefined options: " + f"[{', '.join(p_val['enum'])}]") + + parameter_definitions[p_key] = ToolParameterDefinitionsValue( + description=desc, + type=p_val['type'], + required=required + ) - return message_dict + cohere_tool = Tool( + name=tool.name, + description=tool.description, + parameter_definitions=parameter_definitions + ) + + cohere_tools.append(cohere_tool) + + return cohere_tools def _num_tokens_from_string(self, model: str, credentials: dict, text: str) -> int: """ @@ -494,12 +643,16 @@ def _num_tokens_from_string(self, model: str, credentials: dict, text: str) -> i model=model ) - return response.length + return len(response.tokens) def _num_tokens_from_messages(self, model: str, credentials: dict, messages: list[PromptMessage]) -> int: """Calculate num tokens Cohere model.""" - messages = [self._convert_prompt_message_to_dict(m) for m in messages] - message_strs = [f"{message['role']}: {message['message']}" for message in messages] + calc_messages = [] + for message in messages: + cohere_message = self._convert_prompt_message_to_dict(message) + if cohere_message: + calc_messages.append(cohere_message) + message_strs = [f"{message.role}: {message.message}" for message in calc_messages] message_str = "\n".join(message_strs) real_model = model @@ -565,13 +718,21 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] """ return { InvokeConnectionError: [ - cohere.CohereConnectionError + cohere.errors.service_unavailable_error.ServiceUnavailableError + ], + InvokeServerUnavailableError: [ + cohere.errors.internal_server_error.InternalServerError + ], + InvokeRateLimitError: [ + cohere.errors.too_many_requests_error.TooManyRequestsError + ], + InvokeAuthorizationError: [ + cohere.errors.unauthorized_error.UnauthorizedError, + cohere.errors.forbidden_error.ForbiddenError ], - InvokeServerUnavailableError: [], - InvokeRateLimitError: [], - InvokeAuthorizationError: [], InvokeBadRequestError: [ - cohere.CohereAPIError, - cohere.CohereError, + cohere.core.api_error.ApiError, + cohere.errors.bad_request_error.BadRequestError, + cohere.errors.not_found_error.NotFoundError, ] } diff --git a/api/core/model_runtime/model_providers/cohere/rerank/_position.yaml b/api/core/model_runtime/model_providers/cohere/rerank/_position.yaml new file mode 100644 index 0000000000000..4dd58fc1708b0 --- /dev/null +++ b/api/core/model_runtime/model_providers/cohere/rerank/_position.yaml @@ -0,0 +1,4 @@ +- rerank-english-v2.0 +- rerank-english-v3.0 +- rerank-multilingual-v2.0 +- rerank-multilingual-v3.0 diff --git a/api/core/model_runtime/model_providers/cohere/rerank/rerank-english-v3.0.yaml b/api/core/model_runtime/model_providers/cohere/rerank/rerank-english-v3.0.yaml new file mode 100644 index 0000000000000..3779f0b6c25dd --- /dev/null +++ 
b/api/core/model_runtime/model_providers/cohere/rerank/rerank-english-v3.0.yaml @@ -0,0 +1,4 @@ +model: rerank-english-v3.0 +model_type: rerank +model_properties: + context_size: 5120 diff --git a/api/core/model_runtime/model_providers/cohere/rerank/rerank-multilingual-v3.0.yaml b/api/core/model_runtime/model_providers/cohere/rerank/rerank-multilingual-v3.0.yaml new file mode 100644 index 0000000000000..4f6690ba7685b --- /dev/null +++ b/api/core/model_runtime/model_providers/cohere/rerank/rerank-multilingual-v3.0.yaml @@ -0,0 +1,4 @@ +model: rerank-multilingual-v3.0 +model_type: rerank +model_properties: + context_size: 5120 diff --git a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py index 7fee57f670929..4194f27eb94cd 100644 --- a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py @@ -1,6 +1,7 @@ from typing import Optional import cohere +from cohere.core import RequestOptions from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult from core.model_runtime.errors.invoke import ( @@ -44,19 +45,21 @@ def _invoke(self, model: str, credentials: dict, # initialize client client = cohere.Client(credentials.get('api_key')) - results = client.rerank( + response = client.rerank( query=query, documents=docs, model=model, - top_n=top_n + top_n=top_n, + return_documents=True, + request_options=RequestOptions(max_retries=0) ) rerank_documents = [] - for idx, result in enumerate(results): + for idx, result in enumerate(response.results): # format document rerank_document = RerankDocument( index=result.index, - text=result.document['text'], + text=result.document.text, score=result.relevance_score, ) @@ -108,13 +111,21 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] """ return { InvokeConnectionError: [ - cohere.CohereConnectionError, + cohere.errors.service_unavailable_error.ServiceUnavailableError + ], + InvokeServerUnavailableError: [ + cohere.errors.internal_server_error.InternalServerError + ], + InvokeRateLimitError: [ + cohere.errors.too_many_requests_error.TooManyRequestsError + ], + InvokeAuthorizationError: [ + cohere.errors.unauthorized_error.UnauthorizedError, + cohere.errors.forbidden_error.ForbiddenError ], - InvokeServerUnavailableError: [], - InvokeRateLimitError: [], - InvokeAuthorizationError: [], InvokeBadRequestError: [ - cohere.CohereAPIError, - cohere.CohereError, + cohere.core.api_error.ApiError, + cohere.errors.bad_request_error.BadRequestError, + cohere.errors.not_found_error.NotFoundError, ] } diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py index 5eec721841660..8269a4181045f 100644 --- a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py @@ -3,7 +3,7 @@ import cohere import numpy as np -from cohere.responses import Tokens +from cohere.core import RequestOptions from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult @@ -52,8 +52,8 @@ def _invoke(self, model: str, credentials: dict, text=text ) - for j in range(0, tokenize_response.length, context_size): - tokens += [tokenize_response.token_strings[j: j + 
context_size]] + for j in range(0, len(tokenize_response), context_size): + tokens += [tokenize_response[j: j + context_size]] indices += [i] batched_embeddings = [] @@ -127,9 +127,9 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int except Exception as e: raise self._transform_invoke_error(e) - return response.length + return len(response) - def _tokenize(self, model: str, credentials: dict, text: str) -> Tokens: + def _tokenize(self, model: str, credentials: dict, text: str) -> list[str]: """ Tokenize text :param model: model name @@ -138,17 +138,19 @@ def _tokenize(self, model: str, credentials: dict, text: str) -> Tokens: :return: """ if not text: - return Tokens([], [], {}) + return [] # initialize client client = cohere.Client(credentials.get('api_key')) response = client.tokenize( text=text, - model=model + model=model, + offline=False, + request_options=RequestOptions(max_retries=0) ) - return response + return response.token_strings def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -184,10 +186,11 @@ def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> response = client.embed( texts=texts, model=model, - input_type='search_document' if len(texts) > 1 else 'search_query' + input_type='search_document' if len(texts) > 1 else 'search_query', + request_options=RequestOptions(max_retries=1) ) - return response.embeddings, response.meta['billed_units']['input_tokens'] + return response.embeddings, int(response.meta.billed_units.input_tokens) def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ @@ -231,13 +234,21 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] """ return { InvokeConnectionError: [ - cohere.CohereConnectionError + cohere.errors.service_unavailable_error.ServiceUnavailableError + ], + InvokeServerUnavailableError: [ + cohere.errors.internal_server_error.InternalServerError + ], + InvokeRateLimitError: [ + cohere.errors.too_many_requests_error.TooManyRequestsError + ], + InvokeAuthorizationError: [ + cohere.errors.unauthorized_error.UnauthorizedError, + cohere.errors.forbidden_error.ForbiddenError ], - InvokeServerUnavailableError: [], - InvokeRateLimitError: [], - InvokeAuthorizationError: [], InvokeBadRequestError: [ - cohere.CohereAPIError, - cohere.CohereError, + cohere.core.api_error.ApiError, + cohere.errors.bad_request_error.BadRequestError, + cohere.errors.not_found_error.NotFoundError, ] } diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml new file mode 100644 index 0000000000000..d65dc02674979 --- /dev/null +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml @@ -0,0 +1,39 @@ +model: gemini-1.5-pro-latest +label: + en_US: Gemini 1.5 Pro +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 1048576 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false + - name: max_tokens_to_sample + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 + - name: response_format + use_template: response_format +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml index ffdc9c3659756..4e9f59e7da94f 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml @@ -4,6 +4,8 @@ label: model_type: llm features: - agent-thought + - tool-call + - stream-tool-call model_properties: mode: chat context_size: 30720 diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index 2feff8ebe9cf6..27912b13cc9b5 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -1,7 +1,9 @@ +import json import logging from collections.abc import Generator from typing import Optional, Union +import google.ai.generativelanguage as glm import google.api_core.exceptions as exceptions import google.generativeai as genai import google.generativeai.client as client @@ -13,9 +15,9 @@ AssistantPromptMessage, PromptMessage, PromptMessageContentType, - PromptMessageRole, PromptMessageTool, SystemPromptMessage, + ToolPromptMessage, UserPromptMessage, ) from core.model_runtime.errors.invoke import ( @@ -62,7 +64,7 @@ def _invoke(self, model: str, credentials: dict, :return: full response or stream response chunk generator result """ # invoke model - return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) + return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: @@ -94,6 +96,32 @@ def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: ) return text.rstrip() + + def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool: + """ + Convert tool messages to glm tools + + :param tools: tool messages + :return: glm tools + """ + return glm.Tool( + function_declarations=[ + glm.FunctionDeclaration( + name=tool.name, + parameters=glm.Schema( + type=glm.Type.OBJECT, + properties={ + key: { + 'type_': value.get('type', 'string').upper(), + 'description': value.get('description', ''), + 'enum': value.get('enum', []) + } for key, value in tool.parameters.get('properties', {}).items() + }, + required=tool.parameters.get('required', []) + ), + ) for tool in tools + ] + ) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -105,7 +133,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: """ try: - ping_message = PromptMessage(content="ping", role="system") + ping_message = SystemPromptMessage(content="ping") self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5}) except Exception as ex: @@ -114,8 +142,9 @@ def validate_credentials(self, model: str, credentials: dict) -> None: def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + tools: 
Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, + stream: bool = True, user: Optional[str] = None + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -153,7 +182,6 @@ def _generate(self, model: str, credentials: dict, else: history.append(content) - # Create a new ClientManager with tenant's API key new_client_manager = client._ClientManager() new_client_manager.configure(api_key=credentials["google_api_key"]) @@ -167,14 +195,15 @@ def _generate(self, model: str, credentials: dict, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, } - + response = google_model.generate_content( contents=history, generation_config=genai.types.GenerationConfig( **config_kwargs ), stream=stream, - safety_settings=safety_settings + safety_settings=safety_settings, + tools=self._convert_tools_to_glm_tool(tools) if tools else None, ) if stream: @@ -228,43 +257,61 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon """ index = -1 for chunk in response: - content = chunk.text - index += 1 - - assistant_prompt_message = AssistantPromptMessage( - content=content if content else '', - ) - - if not response._done: - - # transform assistant message to prompt message - yield LLMResultChunk( - model=model, - prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + for part in chunk.parts: + assistant_prompt_message = AssistantPromptMessage( + content='' ) - else: - - # calculate num tokens - prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) - - # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - - yield LLMResultChunk( - model=model, - prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - finish_reason=chunk.candidates[0].finish_reason, - usage=usage + + if part.text: + assistant_prompt_message.content += part.text + + if part.function_call: + assistant_prompt_message.tool_calls = [ + AssistantPromptMessage.ToolCall( + id=part.function_call.name, + type='function', + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=part.function_call.name, + arguments=json.dumps({ + key: value + for key, value in part.function_call.args.items() + }) + ) + ) + ] + + index += 1 + + if not response._done: + + # transform assistant message to prompt message + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=index, + message=assistant_prompt_message + ) + ) + else: + + # calculate num tokens + prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) + completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=index, + message=assistant_prompt_message, + finish_reason=chunk.candidates[0].finish_reason, + usage=usage + ) ) - ) def _convert_one_message_to_text(self, message: PromptMessage) -> str: """ @@ -288,6 +335,8 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: message_text = f"{ai_prompt} {content}" elif 
isinstance(message, SystemPromptMessage): message_text = f"{human_prompt} {content}" + elif isinstance(message, ToolPromptMessage): + message_text = f"{human_prompt} {content}" else: raise ValueError(f"Got unknown type {message}") @@ -300,26 +349,53 @@ def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType: :param message: one PromptMessage :return: glm Content representation of message """ - - parts = [] - if (isinstance(message.content, str)): - parts.append(to_part(message.content)) + if isinstance(message, UserPromptMessage): + glm_content = { + "role": "user", + "parts": [] + } + if (isinstance(message.content, str)): + glm_content['parts'].append(to_part(message.content)) + else: + for c in message.content: + if c.type == PromptMessageContentType.TEXT: + glm_content['parts'].append(to_part(c.data)) + else: + metadata, data = c.data.split(',', 1) + mime_type = metadata.split(';', 1)[0].split(':')[1] + blob = {"inline_data":{"mime_type":mime_type,"data":data}} + glm_content['parts'].append(blob) + return glm_content + elif isinstance(message, AssistantPromptMessage): + glm_content = { + "role": "model", + "parts": [] + } + if message.content: + glm_content['parts'].append(to_part(message.content)) + if message.tool_calls: + glm_content["parts"].append(to_part(glm.FunctionCall( + name=message.tool_calls[0].function.name, + args=json.loads(message.tool_calls[0].function.arguments), + ))) + return glm_content + elif isinstance(message, SystemPromptMessage): + return { + "role": "user", + "parts": [to_part(message.content)] + } + elif isinstance(message, ToolPromptMessage): + return { + "role": "function", + "parts": [glm.Part(function_response=glm.FunctionResponse( + name=message.name, + response={ + "response": message.content + } + ))] + } else: - for c in message.content: - if c.type == PromptMessageContentType.TEXT: - parts.append(to_part(c.data)) - else: - metadata, data = c.data.split(',', 1) - mime_type = metadata.split(';', 1)[0].split(':')[1] - blob = {"inline_data":{"mime_type":mime_type,"data":data}} - parts.append(blob) - - glm_content = { - "role": "user" if message.role in (PromptMessageRole.USER, PromptMessageRole.SYSTEM) else "model", - "parts": parts - } - - return glm_content + raise ValueError(f"Got unknown type {message}") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: diff --git a/api/core/rag/retrieval/agent/output_parser/__init__.py b/api/core/model_runtime/model_providers/modelhub/__init__.py similarity index 100% rename from api/core/rag/retrieval/agent/output_parser/__init__.py rename to api/core/model_runtime/model_providers/modelhub/__init__.py diff --git a/api/core/model_runtime/model_providers/modelhub/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/modelhub/_assets/icon_l_en.svg new file mode 100644 index 0000000000000..f834befb0e3d9 --- /dev/null +++ b/api/core/model_runtime/model_providers/modelhub/_assets/icon_l_en.svg @@ -0,0 +1,4 @@ + + + +
ModelHub
\ No newline at end of file diff --git a/api/core/model_runtime/model_providers/modelhub/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/modelhub/_assets/icon_s_en.svg new file mode 100644 index 0000000000000..70686f9b3b58a --- /dev/null +++ b/api/core/model_runtime/model_providers/modelhub/_assets/icon_s_en.svg @@ -0,0 +1,4 @@ + + + + diff --git a/api/core/model_runtime/model_providers/modelhub/_common.py b/api/core/model_runtime/model_providers/modelhub/_common.py new file mode 100644 index 0000000000000..51950ca377842 --- /dev/null +++ b/api/core/model_runtime/model_providers/modelhub/_common.py @@ -0,0 +1,44 @@ + +import requests + +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) + + +class _CommonOAI_API_Compat: + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeAuthorizationError: [ + requests.exceptions.InvalidHeader, # Missing or Invalid API Key + ], + InvokeBadRequestError: [ + requests.exceptions.HTTPError, # Invalid Endpoint URL or model name + requests.exceptions.InvalidURL, # Misconfigured request or other API error + ], + InvokeRateLimitError: [ + requests.exceptions.RetryError # Too many requests sent in a short period of time + ], + InvokeServerUnavailableError: [ + requests.exceptions.ConnectionError, # Engine Overloaded + requests.exceptions.HTTPError # Server Error + ], + InvokeConnectionError: [ + requests.exceptions.ConnectTimeout, # Timeout + requests.exceptions.ReadTimeout # Timeout + ] + } \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/modelhub/llm/Baichuan2-Turbo.yaml b/api/core/model_runtime/model_providers/modelhub/llm/Baichuan2-Turbo.yaml new file mode 100644 index 0000000000000..eeaabff6ea409 --- /dev/null +++ b/api/core/model_runtime/model_providers/modelhub/llm/Baichuan2-Turbo.yaml @@ -0,0 +1,27 @@ +model: Baichuan2-Turbo +label: + zh_Hans: Baichuan2-Turbo + en_US: Baichuan2-Turbo +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 8192 +pricing: + input: "0.003" + output: "0.004" + unit: "0.001" + currency: USD diff --git a/api/core/model_runtime/model_providers/modelhub/llm/__init__.py b/api/core/model_runtime/model_providers/modelhub/llm/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/api/core/model_runtime/model_providers/modelhub/llm/_position.yaml b/api/core/model_runtime/model_providers/modelhub/llm/_position.yaml new file mode 100644 index 0000000000000..efa975755e5a7 --- /dev/null +++ b/api/core/model_runtime/model_providers/modelhub/llm/_position.yaml @@ -0,0 +1,4 @@ +- gpt-3.5-turbo +- gpt-4 +- glm-3-turbo +- glm-4 \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/modelhub/llm/glm-3-turbo.yaml b/api/core/model_runtime/model_providers/modelhub/llm/glm-3-turbo.yaml new file mode 100644 index 
0000000000000..06b42a14ba845 --- /dev/null +++ b/api/core/model_runtime/model_providers/modelhub/llm/glm-3-turbo.yaml @@ -0,0 +1,21 @@ +model: glm-3-turbo +label: + zh_Hans: glm-3-turbo + en_US: glm-3-turbo +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + default: 0.01 +pricing: + input: "0.003" + output: "0.004" + unit: "0.001" + currency: USD diff --git a/api/core/model_runtime/model_providers/modelhub/llm/glm-4.yaml b/api/core/model_runtime/model_providers/modelhub/llm/glm-4.yaml new file mode 100644 index 0000000000000..c23686846a9f1 --- /dev/null +++ b/api/core/model_runtime/model_providers/modelhub/llm/glm-4.yaml @@ -0,0 +1,21 @@ +model: glm-4 +label: + zh_Hans: glm-4 + en_US: glm-4 +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + default: 0.01 +pricing: + input: "0.003" + output: "0.004" + unit: "0.001" + currency: USD diff --git a/api/core/model_runtime/model_providers/modelhub/llm/gpt-3.5-turbo.yaml b/api/core/model_runtime/model_providers/modelhub/llm/gpt-3.5-turbo.yaml new file mode 100644 index 0000000000000..72a3b7065dc0f --- /dev/null +++ b/api/core/model_runtime/model_providers/modelhub/llm/gpt-3.5-turbo.yaml @@ -0,0 +1,33 @@ +model: gpt-3.5-turbo +label: + zh_Hans: gpt-3.5-turbo + en_US: gpt-3.5-turbo +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 8192 + - name: response_format + use_template: response_format +pricing: + input: "0.003" + output: "0.004" + unit: "0.001" + currency: USD diff --git a/api/core/model_runtime/model_providers/modelhub/llm/gpt-4.yaml b/api/core/model_runtime/model_providers/modelhub/llm/gpt-4.yaml new file mode 100644 index 0000000000000..415b8dd77bf59 --- /dev/null +++ b/api/core/model_runtime/model_providers/modelhub/llm/gpt-4.yaml @@ -0,0 +1,33 @@ +model: gpt-4 +label: + zh_Hans: gpt-4 + en_US: gpt-4 +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 8192 + - name: response_format + use_template: response_format +pricing: + input: "0.003" + output: "0.004" + unit: "0.001" + currency: USD diff --git a/api/core/model_runtime/model_providers/modelhub/llm/llm.py b/api/core/model_runtime/model_providers/modelhub/llm/llm.py new file mode 100644 index 0000000000000..1c559077920ea --- /dev/null +++ b/api/core/model_runtime/model_providers/modelhub/llm/llm.py @@ -0,0 +1,782 @@ +import json +import logging +from collections.abc import Generator +from decimal import Decimal +from typing import Optional, Union, cast +from urllib.parse import urljoin + +import 
requests + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContent, + PromptMessageContentType, + PromptMessageFunction, + PromptMessageTool, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + DefaultParameterName, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, + PriceConfig, +) +from core.model_runtime.errors.invoke import InvokeError +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.modelhub._common import _CommonOAI_API_Compat +from core.model_runtime.utils import helper + +logger = logging.getLogger(__name__) + + +class ModelHubLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): + """ + Model class for OpenAI large language model. + """ + + def _invoke(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, + stream: bool = True, user: Optional[str] = None) \ + -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param tools: tools for tool calling + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ + + # text completion model + return self._generate( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user + ) + + def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None) -> int: + """ + Get number of tokens for given prompt messages + + :param model: + :param credentials: + :param prompt_messages: + :param tools: tools for tool calling + :return: + """ + return self._num_tokens_from_messages(model, prompt_messages, tools) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials using requests to ensure compatibility with all providers following OpenAI's API standard. 
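+ A minimal "ping" chat completion request (max_tokens=5) is posted to the provider's chat/completions endpoint and the response object is checked to be 'chat.completion'.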
+ + :param model: model name + :param credentials: model credentials + :return: + """ + try: + headers = { + 'Content-Type': 'application/json' + } + + api_key = credentials.get('api_key') + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + endpoint_url = credentials['endpoint_url'] + if not endpoint_url.endswith('/'): + endpoint_url += '/' + + # prepare the payload for a simple ping to the model + data = { + 'model': model, + 'max_tokens': 5 + } + + data['messages'] = [ + { + "role": "user", + "content": "ping" + }, + ] + endpoint_url = urljoin(endpoint_url, 'chat/completions') + + # send a post request to validate the credentials + response = requests.post( + endpoint_url, + headers=headers, + json=data, + timeout=(10, 60) + ) + + if response.status_code != 200: + raise CredentialsValidateFailedError( + f'Credentials validation failed with status code {response.status_code}') + + try: + json_result = response.json() + except json.JSONDecodeError as e: + raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + + if json_result['object'] == '': + json_result['object'] = 'chat.completion' + + if 'object' not in json_result or json_result['object'] != 'chat.completion': + raise CredentialsValidateFailedError( + 'Credentials validation failed: invalid response object, must be \'chat.completion\'') + except CredentialsValidateFailedError: + raise + except Exception as ex: + raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + features = [] + + function_calling_type = credentials.get('function_calling_type', 'no_call') + if function_calling_type in ['function_call']: + features.append(ModelFeature.TOOL_CALL) + elif function_calling_type in ['tool_call']: + features.append(ModelFeature.MULTI_TOOL_CALL) + + stream_function_calling = credentials.get('stream_function_calling', 'supported') + if stream_function_calling == 'supported': + features.append(ModelFeature.STREAM_TOOL_CALL) + + vision_support = credentials.get('vision_support', 'not_support') + if vision_support == 'support': + features.append(ModelFeature.VISION) + + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + features=features, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")), + ModelPropertyKey.MODE: credentials.get('mode'), + }, + parameter_rules=[ + ParameterRule( + name=DefaultParameterName.TEMPERATURE.value, + label=I18nObject(en_US="Temperature"), + type=ParameterType.FLOAT, + default=float(credentials.get('temperature', 0.7)), + min=0, + max=2, + precision=2 + ), + ParameterRule( + name=DefaultParameterName.TOP_P.value, + label=I18nObject(en_US="Top P"), + type=ParameterType.FLOAT, + default=float(credentials.get('top_p', 1)), + min=0, + max=1, + precision=2 + ), + ParameterRule( + name=DefaultParameterName.FREQUENCY_PENALTY.value, + label=I18nObject(en_US="Frequency Penalty"), + type=ParameterType.FLOAT, + default=float(credentials.get('frequency_penalty', 0)), + min=-2, + max=2 + ), + ParameterRule( + name=DefaultParameterName.PRESENCE_PENALTY.value, + label=I18nObject(en_US="Presence Penalty"), + type=ParameterType.FLOAT, + default=float(credentials.get('presence_penalty', 0)), + min=-2, + max=2 + ), + ParameterRule( + 
name=DefaultParameterName.MAX_TOKENS.value, + label=I18nObject(en_US="Max Tokens"), + type=ParameterType.INT, + default=512, + min=1, + max=int(credentials.get('max_tokens_to_sample', 4096)), + ) + ], + pricing=PriceConfig( + input=Decimal(credentials.get('input_price', 0)), + output=Decimal(credentials.get('output_price', 0)), + unit=Decimal(credentials.get('unit', 0)), + currency=credentials.get('currency', "USD") + ), + ) + + entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value + + return entity + + # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard. + def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, + stream: bool = True, \ + user: Optional[str] = None) -> Union[LLMResult, Generator]: + """ + Invoke llm completion model + + :param model: model name + :param credentials: credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ + headers = { + 'Content-Type': 'application/json', + 'Accept-Charset': 'utf-8', + } + + api_key = credentials.get('api_key') + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + endpoint_url = credentials["endpoint_url"] + if not endpoint_url.endswith('/'): + endpoint_url += '/' + + data = { + "model": model, + "stream": stream, + **model_parameters + } + + endpoint_url = urljoin(endpoint_url, 'chat/completions') + data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] + + # annotate tools with names, descriptions, etc. 
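+ # Two request shapes are supported depending on the provider's function calling flavour:
+ # 'function_call' places the raw tool schemas under the legacy top-level 'functions' key,
+ # while 'tool_call' serializes each tool as a PromptMessageFunction under 'tools' and
+ # lets the server decide when to invoke one via tool_choice="auto".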
+ function_calling_type = credentials.get('function_calling_type', 'no_call') + formatted_tools = [] + if tools: + if function_calling_type == 'function_call': + data['functions'] = [{ + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters + } for tool in tools] + elif function_calling_type == 'tool_call': + data["tool_choice"] = "auto" + + for tool in tools: + formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool))) + + data["tools"] = formatted_tools + + if stop: + data["stop"] = stop + + if user: + data["user"] = user + + response = requests.post( + endpoint_url, + headers=headers, + json=data, + timeout=(10, 60), + stream=stream + ) + + if response.encoding is None or response.encoding == 'ISO-8859-1': + response.encoding = 'utf-8' + + if response.status_code != 200: + raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}") + + if stream: + return self._handle_generate_stream_response(model, credentials, response, prompt_messages) + + return self._handle_generate_response(model, credentials, response, prompt_messages) + + def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response, + prompt_messages: list[PromptMessage]) -> Generator: + """ + Handle llm stream response + + :param model: model name + :param credentials: model credentials + :param response: streamed response + :param prompt_messages: prompt messages + :return: llm response chunk generator + """ + full_assistant_content = '' + chunk_index = 0 + + def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \ + -> LLMResultChunk: + # calculate num tokens + prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) + completion_tokens = self._num_tokens_from_string(model, full_assistant_content) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + return LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=index, + message=message, + finish_reason=finish_reason, + usage=usage + ) + ) + + # delimiter for stream response, need unicode_escape + import codecs + delimiter = credentials.get("stream_mode_delimiter", "\n\n") + delimiter = codecs.decode(delimiter, "unicode_escape") + + tools_calls: list[AssistantPromptMessage.ToolCall] = [] + + def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]): + def get_tool_call(tool_call_id: str): + if not tool_call_id: + return tools_calls[-1] + + tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None) + if tool_call is None: + tool_call = AssistantPromptMessage.ToolCall( + id=tool_call_id, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name="", + arguments="" + ) + ) + tools_calls.append(tool_call) + + return tool_call + + for new_tool_call in new_tool_calls: + # get tool call + tool_call = get_tool_call(new_tool_call.function.name) + # update tool call + if new_tool_call.id: + tool_call.id = new_tool_call.id + if new_tool_call.type: + tool_call.type = new_tool_call.type + if new_tool_call.function.name: + tool_call.function.name = new_tool_call.function.name + if new_tool_call.function.arguments: + tool_call.function.arguments += new_tool_call.function.arguments + + finish_reason = 'Unknown' + + for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter): + if chunk: + # 
ignore sse comments + if chunk.startswith(':'): + continue + decoded_chunk = chunk.strip().lstrip('data: ').lstrip() + + try: + chunk_json = json.loads(decoded_chunk) + # stream ended + except json.JSONDecodeError as e: + yield create_final_llm_result_chunk( + index=chunk_index + 1, + message=AssistantPromptMessage(content=""), + finish_reason="Non-JSON encountered." + ) + break + if not chunk_json or len(chunk_json['choices']) == 0: + continue + + choice = chunk_json['choices'][0] + finish_reason = chunk_json['choices'][0].get('finish_reason') + chunk_index += 1 + + if 'delta' in choice: + delta = choice['delta'] + delta_content = delta.get('content') + + assistant_message_tool_calls = None + + if 'tool_calls' in delta and credentials.get('function_calling_type', 'no_call') == 'tool_call': + assistant_message_tool_calls = delta.get('tool_calls', None) + elif 'function_call' in delta and credentials.get('function_calling_type', 'no_call') == 'function_call': + assistant_message_tool_calls = [{ + 'id': 'tool_call_id', + 'type': 'function', + 'function': delta.get('function_call', {}) + }] + + # assistant_message_function_call = delta.delta.function_call + + # extract tool calls from response + if assistant_message_tool_calls: + tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) + increase_tool_call(tool_calls) + + if delta_content is None or delta_content == '': + continue + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage( + content=delta_content, + ) + + # reset tool calls + tool_calls = [] + full_assistant_content += delta_content + elif 'text' in choice: + choice_text = choice.get('text', '') + if choice_text == '': + continue + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage(content=choice_text) + full_assistant_content += choice_text + else: + continue + + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=chunk_index, + message=assistant_prompt_message, + ) + ) + + chunk_index += 1 + + if tools_calls: + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=chunk_index, + message=AssistantPromptMessage( + tool_calls=tools_calls, + content="" + ), + ) + ) + + yield create_final_llm_result_chunk( + index=chunk_index, + message=AssistantPromptMessage(content=""), + finish_reason=finish_reason + ) + + def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response, + prompt_messages: list[PromptMessage]) -> LLMResult: + + response_json = response.json() + + output = response_json['choices'][0] + + response_content = '' + tool_calls = None + function_calling_type = credentials.get('function_calling_type', 'no_call') + response_content = output.get('message', {})['content'] + if function_calling_type == 'tool_call': + tool_calls = output.get('message', {}).get('tool_calls') + elif function_calling_type == 'function_call': + tool_calls = output.get('message', {}).get('function_call') + + + assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[]) + + if tool_calls: + if function_calling_type == 'tool_call': + assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls) + elif function_calling_type == 'function_call': + assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)] + + usage = response_json.get("usage") + if usage: + # transform usage + 
prompt_tokens = usage["prompt_tokens"] + completion_tokens = usage["completion_tokens"] + else: + # calculate num tokens + prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) + completion_tokens = self._num_tokens_from_string(model, assistant_message.content) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + # transform response + result = LLMResult( + model=response_json["model"], + prompt_messages=prompt_messages, + message=assistant_message, + usage=usage, + ) + + return result + + def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: + """ + Convert PromptMessage to dict for OpenAI API format + """ + if isinstance(message, UserPromptMessage): + message = cast(UserPromptMessage, message) + if isinstance(message.content, str): + message_dict = {"role": "user", "content": message.content} + else: + sub_messages = [] + for message_content in message.content: + if message_content.type == PromptMessageContentType.TEXT: + message_content = cast(PromptMessageContent, message_content) + sub_message_dict = { + "type": "text", + "text": message_content.data + } + sub_messages.append(sub_message_dict) + elif message_content.type == PromptMessageContentType.IMAGE: + message_content = cast(ImagePromptMessageContent, message_content) + sub_message_dict = { + "type": "image_url", + "image_url": { + "url": message_content.data, + "detail": message_content.detail.value + } + } + sub_messages.append(sub_message_dict) + + message_dict = {"role": "user", "content": sub_messages} + elif isinstance(message, AssistantPromptMessage): + message = cast(AssistantPromptMessage, message) + message_dict = {"role": "assistant", "content": message.content} + if message.tool_calls: + # message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call + # in + # message.tool_calls] + + function_call = message.tool_calls[0] + message_dict["function_call"] = { + "name": function_call.function.name, + "arguments": function_call.function.arguments, + } + elif isinstance(message, SystemPromptMessage): + message = cast(SystemPromptMessage, message) + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, ToolPromptMessage): + message = cast(ToolPromptMessage, message) + # message_dict = { + # "role": "tool", + # "content": message.content, + # "tool_call_id": message.tool_call_id + # } + message_dict = { + "role": "function", + "content": message.content, + "name": message.tool_call_id + } + else: + raise ValueError(f"Got unknown type {message}") + + if message.name: + message_dict["name"] = message.name + + return message_dict + + def _num_tokens_from_string(self, model: str, text: Union[str, list[PromptMessageContent]], + tools: Optional[list[PromptMessageTool]] = None) -> int: + """ + Approximate num tokens for model with gpt2 tokenizer. 
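+ Accepts either a plain string or a list of message content parts; only text parts contribute to the count, and tool definitions (if given) are counted separately via _num_tokens_for_tools.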
+ + :param model: model name + :param text: prompt text + :param tools: tools for tool calling + :return: number of tokens + """ + if isinstance(text, str): + full_text = text + else: + full_text = '' + for message_content in text: + if message_content.type == PromptMessageContentType.TEXT: + message_content = cast(PromptMessageContent, message_content) + full_text += message_content.data + + num_tokens = self._get_num_tokens_by_gpt2(full_text) + + if tools: + num_tokens += self._num_tokens_for_tools(tools) + + return num_tokens + + def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None) -> int: + """ + Approximate num tokens with GPT2 tokenizer. + """ + + tokens_per_message = 3 + tokens_per_name = 1 + + num_tokens = 0 + messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages] + for message in messages_dict: + num_tokens += tokens_per_message + for key, value in message.items(): + # Cast str(value) in case the message value is not a string + # This occurs with function messages + # TODO: The current token calculation method for the image type is not implemented, + # which need to download the image and then get the resolution for calculation, + # and will increase the request delay + if isinstance(value, list): + text = '' + for item in value: + if isinstance(item, dict) and item['type'] == 'text': + text += item['text'] + + value = text + + if key == "tool_calls": + for tool_call in value: + for t_key, t_value in tool_call.items(): + num_tokens += self._get_num_tokens_by_gpt2(t_key) + if t_key == "function": + for f_key, f_value in t_value.items(): + num_tokens += self._get_num_tokens_by_gpt2(f_key) + num_tokens += self._get_num_tokens_by_gpt2(f_value) + else: + num_tokens += self._get_num_tokens_by_gpt2(t_key) + num_tokens += self._get_num_tokens_by_gpt2(t_value) + else: + num_tokens += self._get_num_tokens_by_gpt2(str(value)) + + if key == "name": + num_tokens += tokens_per_name + + # every reply is primed with assistant + num_tokens += 3 + + if tools: + num_tokens += self._num_tokens_for_tools(tools) + + return num_tokens + + def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int: + """ + Calculate num tokens for tool calling with tiktoken package. 
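+ (The implementation below counts with the built-in GPT-2 tokenizer via _get_num_tokens_by_gpt2, so the result is an approximation.)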
+ + :param tools: tools for tool calling + :return: number of tokens + """ + num_tokens = 0 + for tool in tools: + num_tokens += self._get_num_tokens_by_gpt2('type') + num_tokens += self._get_num_tokens_by_gpt2('function') + num_tokens += self._get_num_tokens_by_gpt2('function') + + # calculate num tokens for function object + num_tokens += self._get_num_tokens_by_gpt2('name') + num_tokens += self._get_num_tokens_by_gpt2(tool.name) + num_tokens += self._get_num_tokens_by_gpt2('description') + num_tokens += self._get_num_tokens_by_gpt2(tool.description) + parameters = tool.parameters + num_tokens += self._get_num_tokens_by_gpt2('parameters') + if 'title' in parameters: + num_tokens += self._get_num_tokens_by_gpt2('title') + num_tokens += self._get_num_tokens_by_gpt2(parameters.get("title")) + num_tokens += self._get_num_tokens_by_gpt2('type') + num_tokens += self._get_num_tokens_by_gpt2(parameters.get("type")) + if 'properties' in parameters: + num_tokens += self._get_num_tokens_by_gpt2('properties') + for key, value in parameters.get('properties').items(): + num_tokens += self._get_num_tokens_by_gpt2(key) + for field_key, field_value in value.items(): + num_tokens += self._get_num_tokens_by_gpt2(field_key) + if field_key == 'enum': + for enum_field in field_value: + num_tokens += 3 + num_tokens += self._get_num_tokens_by_gpt2(enum_field) + else: + num_tokens += self._get_num_tokens_by_gpt2(field_key) + num_tokens += self._get_num_tokens_by_gpt2(str(field_value)) + if 'required' in parameters: + num_tokens += self._get_num_tokens_by_gpt2('required') + for required_field in parameters['required']: + num_tokens += 3 + num_tokens += self._get_num_tokens_by_gpt2(required_field) + + return num_tokens + + def _extract_response_tool_calls(self, + response_tool_calls: list[dict]) \ + -> list[AssistantPromptMessage.ToolCall]: + """ + Extract tool calls from response + + :param response_tool_calls: response tool calls + :return: list of tool calls + """ + tool_calls = [] + if response_tool_calls: + for response_tool_call in response_tool_calls: + function = AssistantPromptMessage.ToolCall.ToolCallFunction( + name=response_tool_call.get("function", {}).get("name", ""), + arguments=response_tool_call.get("function", {}).get("arguments", "") + ) + + tool_call = AssistantPromptMessage.ToolCall( + id=response_tool_call.get("id", ""), + type=response_tool_call.get("type", ""), + function=function + ) + tool_calls.append(tool_call) + + return tool_calls + + def _extract_response_function_call(self, response_function_call) \ + -> AssistantPromptMessage.ToolCall: + """ + Extract function call from response + + :param response_function_call: response function call + :return: tool call + """ + tool_call = None + if response_function_call: + function = AssistantPromptMessage.ToolCall.ToolCallFunction( + name=response_function_call.get('name', ''), + arguments=response_function_call.get('arguments', '') + ) + + tool_call = AssistantPromptMessage.ToolCall( + id=response_function_call.get('id', ''), + type="function", + function=function + ) + + return tool_call diff --git a/api/core/model_runtime/model_providers/modelhub/modelhub.py b/api/core/model_runtime/model_providers/modelhub/modelhub.py new file mode 100644 index 0000000000000..344c835308eb4 --- /dev/null +++ b/api/core/model_runtime/model_providers/modelhub/modelhub.py @@ -0,0 +1,11 @@ +import logging + +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class 
ModelHubProvider(ModelProvider): + + def validate_provider_credentials(self, credentials: dict) -> None: + pass diff --git a/api/core/model_runtime/model_providers/modelhub/modelhub.yaml b/api/core/model_runtime/model_providers/modelhub/modelhub.yaml new file mode 100644 index 0000000000000..8c13bab976db9 --- /dev/null +++ b/api/core/model_runtime/model_providers/modelhub/modelhub.yaml @@ -0,0 +1,152 @@ +provider: modelhub +label: + en_US: ModelHub +description: + en_US: ModelHub + zh_Hans: ModelHub +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.svg +background: "#FFFFFF" +supported_model_types: + - llm + - text-embedding + - rerank +configurate_methods: + - predefined-model + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter full model name + zh_Hans: 输入模型全称 + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: false + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + - variable: endpoint_url + label: + zh_Hans: API endpoint URL + en_US: API endpoint URL + type: text-input + required: false + default: 'https://modelhub.puyuan.tech/api/v1' + placeholder: + zh_Hans: Base URL, e.g. https://api.openai.com/v1 + en_US: Base URL, e.g. https://api.openai.com/v1 + - variable: context_size + label: + zh_Hans: 模型上下文长度 + en_US: Model context size + required: true + type: text-input + default: '4096' + placeholder: + zh_Hans: 在此输入您的模型上下文长度 + en_US: Enter your Model context size + - variable: max_tokens_to_sample + label: + zh_Hans: 最大 token 上限 + en_US: Upper bound for max tokens + show_on: + - variable: __model_type + value: llm + default: '4096' + type: text-input + - variable: function_calling_type + show_on: + - variable: __model_type + value: llm + label: + en_US: Function calling + type: select + required: false + default: no_call + options: + - value: function_call + label: + en_US: Function Call + zh_Hans: Function Call + - value: tool_call + label: + en_US: Tool Call + zh_Hans: Tool Call + - value: no_call + label: + en_US: Not Support + zh_Hans: 不支持 + - variable: stream_function_calling + show_on: + - variable: __model_type + value: llm + label: + en_US: Stream function calling + type: select + required: false + default: not_supported + options: + - value: supported + label: + en_US: Support + zh_Hans: 支持 + - value: not_supported + label: + en_US: Not Support + zh_Hans: 不支持 + - variable: vision_support + show_on: + - variable: __model_type + value: llm + label: + zh_Hans: Vision 支持 + en_US: Vision Support + type: select + required: false + default: no_support + options: + - value: support + label: + en_US: Support + zh_Hans: 支持 + - value: no_support + label: + en_US: Not Support + zh_Hans: 不支持 + - variable: stream_mode_delimiter + label: + zh_Hans: 流模式返回结果的分隔符 + en_US: Delimiter for streaming results + show_on: + - variable: __model_type + value: llm + default: '\n\n' + type: text-input +provider_credential_schema: + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + - variable: endpoint_url + label: + zh_Hans: endpoint_url + en_US: endpoint_url + type: text-input + required: false + default: 'https://modelhub.puyuan.tech/api/v1' + placeholder: + zh_Hans: 在此输入您的 endpoint_url + en_US: Enter your endpoint_url \ No newline at end of file diff --git 
a/api/core/model_runtime/model_providers/modelhub/rerank/__init__.py b/api/core/model_runtime/model_providers/modelhub/rerank/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/api/core/model_runtime/model_providers/modelhub/rerank/bge-reranker-base.yaml b/api/core/model_runtime/model_providers/modelhub/rerank/bge-reranker-base.yaml new file mode 100644 index 0000000000000..62831c8bb9776 --- /dev/null +++ b/api/core/model_runtime/model_providers/modelhub/rerank/bge-reranker-base.yaml @@ -0,0 +1,4 @@ +model: bge-reranker-base +model_type: rerank +model_properties: + context_size: 5120 diff --git a/api/core/model_runtime/model_providers/modelhub/rerank/bge-reranker-v2-m3.yaml b/api/core/model_runtime/model_providers/modelhub/rerank/bge-reranker-v2-m3.yaml new file mode 100644 index 0000000000000..0ee6309ff29e5 --- /dev/null +++ b/api/core/model_runtime/model_providers/modelhub/rerank/bge-reranker-v2-m3.yaml @@ -0,0 +1,4 @@ +model: bge-reranker-v2-m3 +model_type: rerank +model_properties: + context_size: 5120 diff --git a/api/core/model_runtime/model_providers/modelhub/rerank/rerank.py b/api/core/model_runtime/model_providers/modelhub/rerank/rerank.py new file mode 100644 index 0000000000000..2e2623eccfc68 --- /dev/null +++ b/api/core/model_runtime/model_providers/modelhub/rerank/rerank.py @@ -0,0 +1,100 @@ +from typing import Optional + +import numpy as np +from modelhub import ModelhubClient + +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + + +class CohereRerankModel(RerankModel): + """ + Model class for Cohere rerank model. + """ + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id + :return: rerank result + """ + if len(docs) == 0: + return RerankResult(model=model, docs=docs) + + user_name, user_password = credentials["api_key"].split(":") + client = ModelhubClient( + user_name=user_name, + user_password=user_password, + host=credentials["endpoint_url"].replace("v1", ""), + ) + scores = client.cross_embedding( + [[query, doc] for doc in docs], + model=model, + ).scores + sort_idx = np.argsort(scores)[::-1] + rerank_documents = [] + for i in sort_idx[:top_n] if top_n is not None else sort_idx: + rerank_document = RerankDocument(index=i, text=docs[i], score=scores[i]) + if score_threshold is not None: + if scores[i] >= score_threshold: + rerank_documents.append(rerank_document) + else: + rerank_documents.append(rerank_document) + + return RerankResult(model=model, docs=rerank_documents) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self.invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. 
At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8, + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return {} diff --git a/api/core/model_runtime/model_providers/modelhub/text_embedding/__init__.py b/api/core/model_runtime/model_providers/modelhub/text_embedding/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/api/core/model_runtime/model_providers/modelhub/text_embedding/bge-m3.yaml b/api/core/model_runtime/model_providers/modelhub/text_embedding/bge-m3.yaml new file mode 100644 index 0000000000000..13571b282e3ab --- /dev/null +++ b/api/core/model_runtime/model_providers/modelhub/text_embedding/bge-m3.yaml @@ -0,0 +1,9 @@ +model: bge-m3 +model_type: text-embedding +model_properties: + context_size: 8191 + max_chunks: 32 +pricing: + input: '0.00013' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/modelhub/text_embedding/m3e-large.yaml b/api/core/model_runtime/model_providers/modelhub/text_embedding/m3e-large.yaml new file mode 100644 index 0000000000000..005a87f3ef950 --- /dev/null +++ b/api/core/model_runtime/model_providers/modelhub/text_embedding/m3e-large.yaml @@ -0,0 +1,9 @@ +model: m3e-large +model_type: text-embedding +model_properties: + context_size: 8191 + max_chunks: 32 +pricing: + input: '0.00013' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/modelhub/text_embedding/text-embedding-3-large.yaml b/api/core/model_runtime/model_providers/modelhub/text_embedding/text-embedding-3-large.yaml new file mode 100644 index 0000000000000..9489170fdea0b --- /dev/null +++ b/api/core/model_runtime/model_providers/modelhub/text_embedding/text-embedding-3-large.yaml @@ -0,0 +1,9 @@ +model: text-embedding-3-large +model_type: text-embedding +model_properties: + context_size: 8191 + max_chunks: 32 +pricing: + input: '0.00013' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/modelhub/text_embedding/text-embedding-3-small.yaml b/api/core/model_runtime/model_providers/modelhub/text_embedding/text-embedding-3-small.yaml new file mode 100644 index 0000000000000..9f46210e55b97 --- /dev/null +++ b/api/core/model_runtime/model_providers/modelhub/text_embedding/text-embedding-3-small.yaml @@ -0,0 +1,9 @@ +model: text-embedding-3-small +model_type: text-embedding +model_properties: + context_size: 8191 + max_chunks: 32 +pricing: + input: '0.00013' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/modelhub/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/modelhub/text_embedding/text_embedding.py new file mode 100644 index 0000000000000..a1e0e5143fb8c --- /dev/null +++ b/api/core/model_runtime/model_providers/modelhub/text_embedding/text_embedding.py @@ -0,0 +1,244 @@ +import json +import time +from decimal import Decimal +from typing import Optional +from urllib.parse import urljoin + 
+import numpy as np +import requests + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, + PriceConfig, + PriceType, +) +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.model_runtime.model_providers.modelhub._common import _CommonOAI_API_Compat + + +class ModelHubEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): + """ + Model class for an OpenAI API-compatible text embedding model. + """ + + def _invoke(self, model: str, credentials: dict, + texts: list[str], user: Optional[str] = None) \ + -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ + + # Prepare headers and payload for the request + headers = { + 'Content-Type': 'application/json' + } + + api_key = credentials.get('api_key') + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + endpoint_url = credentials.get('endpoint_url') + if not endpoint_url.endswith('/'): + endpoint_url += '/' + + endpoint_url = urljoin(endpoint_url, 'embeddings') + + extra_model_kwargs = {} + if user: + extra_model_kwargs['user'] = user + + extra_model_kwargs['encoding_format'] = 'float' + + # get model properties + context_size = self._get_context_size(model, credentials) + max_chunks = self._get_max_chunks(model, credentials) + + inputs = [] + indices = [] + used_tokens = 0 + + for i, text in enumerate(texts): + + # Here token count is only an approximation based on the GPT2 tokenizer + # TODO: Optimize for better token estimation and chunking + num_tokens = self._get_num_tokens_by_gpt2(text) + + if num_tokens >= context_size: + cutoff = int(len(text) * (np.floor(context_size / num_tokens))) + # if num tokens is larger than context length, only use the start + inputs.append(text[0: cutoff]) + else: + inputs.append(text) + indices += [i] + + batched_embeddings = [] + _iter = range(0, len(inputs), max_chunks) + + for i in _iter: + # Prepare the payload for the request + payload = { + 'input': inputs[i: i + max_chunks], + 'model': model, + **extra_model_kwargs + } + + # Make the request to the OpenAI API + response = requests.post( + endpoint_url, + headers=headers, + data=json.dumps(payload), + timeout=(10, 300) + ) + + response.raise_for_status() # Raise an exception for HTTP errors + response_data = response.json() + + # Extract embeddings and used tokens from the response + embeddings_batch = [data['embedding'] for data in response_data['data']] + embedding_used_tokens = response_data['usage']['total_tokens'] + + used_tokens += embedding_used_tokens + batched_embeddings += embeddings_batch + + # calc usage + usage = self._calc_response_usage( + model=model, + credentials=credentials, + tokens=used_tokens + ) + + return TextEmbeddingResult( + embeddings=batched_embeddings, + usage=usage, + model=model + ) + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Approximate number of tokens for given messages using GPT2 tokenizer + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + return 
sum(self._get_num_tokens_by_gpt2(text) for text in texts) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + headers = { + 'Content-Type': 'application/json' + } + + api_key = credentials.get('api_key') + + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + endpoint_url = credentials.get('endpoint_url') + if not endpoint_url.endswith('/'): + endpoint_url += '/' + + endpoint_url = urljoin(endpoint_url, 'embeddings') + + payload = { + 'input': 'ping', + 'model': model + } + + response = requests.post( + url=endpoint_url, + headers=headers, + data=json.dumps(payload), + timeout=(10, 300) + ) + + if response.status_code != 200: + raise CredentialsValidateFailedError( + f'Credentials validation failed with status code {response.status_code}') + + try: + json_result = response.json() + except json.JSONDecodeError as e: + raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + + if 'model' not in json_result: + raise CredentialsValidateFailedError( + 'Credentials validation failed: invalid response') + except CredentialsValidateFailedError: + raise + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.TEXT_EMBEDDING, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.MAX_CHUNKS: 1, + }, + parameter_rules=[], + pricing=PriceConfig( + input=Decimal(credentials.get('input_price', 0)), + unit=Decimal(credentials.get('unit', 0)), + currency=credentials.get('currency', "USD") + ) + ) + + return entity + + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, + credentials=credentials, + price_type=PriceType.INPUT, + tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at + ) + + return usage diff --git a/api/core/model_runtime/model_providers/moonshot/llm/llm.py b/api/core/model_runtime/model_providers/moonshot/llm/llm.py index 05feee877eee9..3e146559c85c5 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/llm.py +++ b/api/core/model_runtime/model_providers/moonshot/llm/llm.py @@ -1,8 +1,31 @@ +import json from collections.abc import Generator -from typing import Optional, Union +from typing import Optional, Union, cast -from core.model_runtime.entities.llm_entities import LLMResult -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +import requests + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from 
core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContent, + PromptMessageContentType, + PromptMessageTool, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel @@ -13,6 +36,7 @@ def _invoke(self, model: str, credentials: dict, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) + self._add_function_call(model, credentials) user = user[:32] if user else None return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -20,7 +44,286 @@ def validate_credentials(self, model: str, credentials: dict) -> None: self._add_custom_parameters(credentials) super().validate_credentials(model, credentials) - @staticmethod - def _add_custom_parameters(credentials: dict) -> None: + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + return AIModelEntity( + model=model, + label=I18nObject(en_US=model, zh_Hans=model), + model_type=ModelType.LLM, + features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] + if credentials.get('function_calling_type') == 'tool_call' + else [], + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)), + ModelPropertyKey.MODE: LLMMode.CHAT.value, + }, + parameter_rules=[ + ParameterRule( + name='temperature', + use_template='temperature', + label=I18nObject(en_US='Temperature', zh_Hans='温度'), + type=ParameterType.FLOAT, + ), + ParameterRule( + name='max_tokens', + use_template='max_tokens', + default=512, + min=1, + max=int(credentials.get('max_tokens', 4096)), + label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'), + type=ParameterType.INT, + ), + ParameterRule( + name='top_p', + use_template='top_p', + label=I18nObject(en_US='Top P', zh_Hans='Top P'), + type=ParameterType.FLOAT, + ), + ] + ) + + def _add_custom_parameters(self, credentials: dict) -> None: credentials['mode'] = 'chat' credentials['endpoint_url'] = 'https://api.moonshot.cn/v1' + + def _add_function_call(self, model: str, credentials: dict) -> None: + model_schema = self.get_model_schema(model, credentials) + if model_schema and set([ + ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL + ]).intersection(model_schema.features or []): + credentials['function_calling_type'] = 'tool_call' + + def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: + """ + Convert PromptMessage to dict for OpenAI API format + """ + if isinstance(message, UserPromptMessage): + message = cast(UserPromptMessage, message) + if isinstance(message.content, str): + message_dict = {"role": "user", "content": message.content} + else: + sub_messages = [] + for message_content in message.content: + if message_content.type == PromptMessageContentType.TEXT: + message_content = cast(PromptMessageContent, message_content) + sub_message_dict = { + "type": "text", + "text": message_content.data + } + sub_messages.append(sub_message_dict) + elif message_content.type == PromptMessageContentType.IMAGE: + message_content = cast(ImagePromptMessageContent, message_content) + 
sub_message_dict = { + "type": "image_url", + "image_url": { + "url": message_content.data, + "detail": message_content.detail.value + } + } + sub_messages.append(sub_message_dict) + message_dict = {"role": "user", "content": sub_messages} + elif isinstance(message, AssistantPromptMessage): + message = cast(AssistantPromptMessage, message) + message_dict = {"role": "assistant", "content": message.content} + if message.tool_calls: + message_dict["tool_calls"] = [] + for function_call in message.tool_calls: + message_dict["tool_calls"].append({ + "id": function_call.id, + "type": function_call.type, + "function": { + "name": function_call.function.name, + "arguments": function_call.function.arguments + } + }) + elif isinstance(message, ToolPromptMessage): + message = cast(ToolPromptMessage, message) + message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} + elif isinstance(message, SystemPromptMessage): + message = cast(SystemPromptMessage, message) + message_dict = {"role": "system", "content": message.content} + else: + raise ValueError(f"Got unknown type {message}") + + if message.name: + message_dict["name"] = message.name + + return message_dict + + def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]: + """ + Extract tool calls from response + + :param response_tool_calls: response tool calls + :return: list of tool calls + """ + tool_calls = [] + if response_tool_calls: + for response_tool_call in response_tool_calls: + function = AssistantPromptMessage.ToolCall.ToolCallFunction( + name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "", + arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else "" + ) + + tool_call = AssistantPromptMessage.ToolCall( + id=response_tool_call["id"] if response_tool_call.get("id") else "", + type=response_tool_call["type"] if response_tool_call.get("type") else "", + function=function + ) + tool_calls.append(tool_call) + + return tool_calls + + def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response, + prompt_messages: list[PromptMessage]) -> Generator: + """ + Handle llm stream response + + :param model: model name + :param credentials: model credentials + :param response: streamed response + :param prompt_messages: prompt messages + :return: llm response chunk generator + """ + full_assistant_content = '' + chunk_index = 0 + + def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \ + -> LLMResultChunk: + # calculate num tokens + prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) + completion_tokens = self._num_tokens_from_string(model, full_assistant_content) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + return LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=index, + message=message, + finish_reason=finish_reason, + usage=usage + ) + ) + + tools_calls: list[AssistantPromptMessage.ToolCall] = [] + finish_reason = "Unknown" + + def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]): + def get_tool_call(tool_name: str): + if not tool_name: + return tools_calls[-1] + + tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None) + if 
tool_call is None: + tool_call = AssistantPromptMessage.ToolCall( + id='', + type='', + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="") + ) + tools_calls.append(tool_call) + + return tool_call + + for new_tool_call in new_tool_calls: + # get tool call + tool_call = get_tool_call(new_tool_call.function.name) + # update tool call + if new_tool_call.id: + tool_call.id = new_tool_call.id + if new_tool_call.type: + tool_call.type = new_tool_call.type + if new_tool_call.function.name: + tool_call.function.name = new_tool_call.function.name + if new_tool_call.function.arguments: + tool_call.function.arguments += new_tool_call.function.arguments + + for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"): + if chunk: + # ignore sse comments + if chunk.startswith(':'): + continue + decoded_chunk = chunk.strip().lstrip('data: ').lstrip() + chunk_json = None + try: + chunk_json = json.loads(decoded_chunk) + # stream ended + except json.JSONDecodeError as e: + yield create_final_llm_result_chunk( + index=chunk_index + 1, + message=AssistantPromptMessage(content=""), + finish_reason="Non-JSON encountered." + ) + break + if not chunk_json or len(chunk_json['choices']) == 0: + continue + + choice = chunk_json['choices'][0] + finish_reason = chunk_json['choices'][0].get('finish_reason') + chunk_index += 1 + + if 'delta' in choice: + delta = choice['delta'] + delta_content = delta.get('content') + + assistant_message_tool_calls = delta.get('tool_calls', None) + # assistant_message_function_call = delta.delta.function_call + + # extract tool calls from response + if assistant_message_tool_calls: + tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) + increase_tool_call(tool_calls) + + if delta_content is None or delta_content == '': + continue + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage( + content=delta_content, + tool_calls=tool_calls if assistant_message_tool_calls else [] + ) + + full_assistant_content += delta_content + elif 'text' in choice: + choice_text = choice.get('text', '') + if choice_text == '': + continue + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage(content=choice_text) + full_assistant_content += choice_text + else: + continue + + # check payload indicator for completion + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=chunk_index, + message=assistant_prompt_message, + ) + ) + + chunk_index += 1 + + if tools_calls: + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=chunk_index, + message=AssistantPromptMessage( + tool_calls=tools_calls, + content="" + ), + ) + ) + + yield create_final_llm_result_chunk( + index=chunk_index, + message=AssistantPromptMessage(content=""), + finish_reason=finish_reason + ) \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/moonshot/moonshot.yaml b/api/core/model_runtime/model_providers/moonshot/moonshot.yaml index 1885ee9d9444b..34c802c2a7afd 100644 --- a/api/core/model_runtime/model_providers/moonshot/moonshot.yaml +++ b/api/core/model_runtime/model_providers/moonshot/moonshot.yaml @@ -20,6 +20,7 @@ supported_model_types: - llm configurate_methods: - predefined-model + - customizable-model provider_credential_schema: credential_form_schemas: - variable: api_key @@ -30,3 +31,51 @@ provider_credential_schema: 
placeholder: zh_Hans: 在此输入您的 API Key en_US: Enter your API Key +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + - variable: context_size + label: + zh_Hans: 模型上下文长度 + en_US: Model context size + required: true + type: text-input + default: '4096' + placeholder: + zh_Hans: 在此输入您的模型上下文长度 + en_US: Enter your Model context size + - variable: max_tokens + label: + zh_Hans: 最大 token 上限 + en_US: Upper bound for max tokens + default: '4096' + type: text-input + - variable: function_calling_type + label: + en_US: Function calling + type: select + required: false + default: no_call + options: + - value: no_call + label: + en_US: Not supported + zh_Hans: 不支持 + - value: tool_call + label: + en_US: Tool Call + zh_Hans: Tool Call diff --git a/api/core/model_runtime/model_providers/nvidia/llm/_position.yaml b/api/core/model_runtime/model_providers/nvidia/llm/_position.yaml index 78ab4cb93eaa4..51e71920e82f4 100644 --- a/api/core/model_runtime/model_providers/nvidia/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/nvidia/llm/_position.yaml @@ -1,4 +1,5 @@ - google/gemma-7b +- google/codegemma-7b - meta/llama2-70b - mistralai/mixtral-8x7b-instruct-v0.1 - fuyu-8b diff --git a/api/core/model_runtime/model_providers/nvidia/llm/codegemma-7b.yaml b/api/core/model_runtime/model_providers/nvidia/llm/codegemma-7b.yaml new file mode 100644 index 0000000000000..ae94b14220906 --- /dev/null +++ b/api/core/model_runtime/model_providers/nvidia/llm/codegemma-7b.yaml @@ -0,0 +1,30 @@ +model: google/codegemma-7b +label: + zh_Hans: google/codegemma-7b + en_US: google/codegemma-7b +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 1024 + - name: frequency_penalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: presence_penalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 diff --git a/api/core/model_runtime/model_providers/nvidia/llm/llm.py b/api/core/model_runtime/model_providers/nvidia/llm/llm.py index 5d05e606b05cf..81291bf6c4921 100644 --- a/api/core/model_runtime/model_providers/nvidia/llm/llm.py +++ b/api/core/model_runtime/model_providers/nvidia/llm/llm.py @@ -24,6 +24,7 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): 'fuyu-8b': 'vlm/adept/fuyu-8b', 'mistralai/mixtral-8x7b-instruct-v0.1': '', 'google/gemma-7b': '', + 'google/codegemma-7b': '', 'meta/llama2-70b': '' } diff --git a/api/core/model_runtime/model_providers/nvidia/nvidia.yaml b/api/core/model_runtime/model_providers/nvidia/nvidia.yaml index c3c316321e60d..4d6da913c16db 100644 --- a/api/core/model_runtime/model_providers/nvidia/nvidia.yaml +++ b/api/core/model_runtime/model_providers/nvidia/nvidia.yaml @@ -1,6 +1,6 @@ provider: nvidia label: - en_US: NVIDIA + en_US: API Catalog icon_small: en_US: icon_s_en.svg icon_large: diff --git a/api/core/model_runtime/model_providers/nvidia/rerank/rerank-qa-mistral-4b.yaml b/api/core/model_runtime/model_providers/nvidia/rerank/rerank-qa-mistral-4b.yaml index 7703ca21abacf..461f4e1cbe47a 100644 --- 
a/api/core/model_runtime/model_providers/nvidia/rerank/rerank-qa-mistral-4b.yaml +++ b/api/core/model_runtime/model_providers/nvidia/rerank/rerank-qa-mistral-4b.yaml @@ -1,4 +1,4 @@ model: nv-rerank-qa-mistral-4b:1 model_type: rerank model_properties: - context_size: 8192 + context_size: 512 diff --git a/api/core/model_runtime/model_providers/ollama/ollama.yaml b/api/core/model_runtime/model_providers/ollama/ollama.yaml index 782667fdab0dc..33747753bd9f6 100644 --- a/api/core/model_runtime/model_providers/ollama/ollama.yaml +++ b/api/core/model_runtime/model_providers/ollama/ollama.yaml @@ -90,9 +90,9 @@ model_credential_schema: options: - value: 'true' label: - en_US: Yes + en_US: 'Yes' zh_Hans: 是 - value: 'false' label: - en_US: No + en_US: 'No' zh_Hans: 否 diff --git a/api/core/model_runtime/model_providers/openai/llm/_position.yaml b/api/core/model_runtime/model_providers/openai/llm/_position.yaml index cc3f3a6d940d3..3808d670c3320 100644 --- a/api/core/model_runtime/model_providers/openai/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/_position.yaml @@ -1,4 +1,6 @@ - gpt-4 +- gpt-4-turbo +- gpt-4-turbo-2024-04-09 - gpt-4-turbo-preview - gpt-4-32k - gpt-4-1106-preview diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-2024-04-09.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-2024-04-09.yaml new file mode 100644 index 0000000000000..6b36361efe80d --- /dev/null +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-2024-04-09.yaml @@ -0,0 +1,57 @@ +model: gpt-4-turbo-2024-04-09 +label: + zh_Hans: gpt-4-turbo-2024-04-09 + en_US: gpt-4-turbo-2024-04-09 +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 4096 + - name: seed + label: + zh_Hans: 种子 + en_US: Seed + type: int + help: + zh_Hans: 如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint + 响应参数来监视变化。 + en_US: If specified, model will make a best effort to sample deterministically, + such that repeated requests with the same seed and parameters should return + the same result. Determinism is not guaranteed, and you should refer to the + system_fingerprint response parameter to monitor changes in the backend. 
+ required: false + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '0.01' + output: '0.03' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo.yaml new file mode 100644 index 0000000000000..575acb7fa294b --- /dev/null +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo.yaml @@ -0,0 +1,57 @@ +model: gpt-4-turbo +label: + zh_Hans: gpt-4-turbo + en_US: gpt-4-turbo +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 4096 + - name: seed + label: + zh_Hans: 种子 + en_US: Seed + type: int + help: + zh_Hans: 如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint + 响应参数来监视变化。 + en_US: If specified, model will make a best effort to sample deterministically, + such that repeated requests with the same seed and parameters should return + the same result. Determinism is not guaranteed, and you should refer to the + system_fingerprint response parameter to monitor changes in the backend. + required: false + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '0.01' + output: '0.03' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 46f17fe19b6f9..b7db39376c8a7 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -547,6 +547,9 @@ def _chat_generate(self, model: str, credentials: dict, if user: extra_model_kwargs['user'] = user + # clear illegal prompt messages + prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages) + # chat model response = client.chat.completions.create( messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], @@ -757,6 +760,31 @@ def _extract_response_function_call(self, response_function_call: FunctionCall | return tool_call + def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: + """ + Clear illegal prompt messages for OpenAI API + + :param model: model name + :param prompt_messages: prompt messages + :return: cleaned prompt messages + """ + checklist = ['gpt-4-turbo', 'gpt-4-turbo-2024-04-09'] + + if model in checklist: + # count how many user messages are there + user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)]) + if user_message_count > 1: + for prompt_message in prompt_messages: + if isinstance(prompt_message, UserPromptMessage): + if isinstance(prompt_message.content, list): + prompt_message.content = '\n'.join([ + item.data if item.type == PromptMessageContentType.TEXT else + 
'[IMAGE]' if item.type == PromptMessageContentType.IMAGE else '' + for item in prompt_message.content + ]) + + return prompt_messages + def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: """ Convert PromptMessage to dict for OpenAI API diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index 8cfec0e34b2f3..45a5b49a8b0d7 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -150,6 +150,11 @@ def validate_credentials(self, model: str, credentials: dict) -> None: except json.JSONDecodeError as e: raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + if (completion_type is LLMMode.CHAT and json_result['object'] == ''): + json_result['object'] = 'chat.completion' + elif (completion_type is LLMMode.COMPLETION and json_result['object'] == ''): + json_result['object'] = 'text_completion' + if (completion_type is LLMMode.CHAT and ('object' not in json_result or json_result['object'] != 'chat.completion')): raise CredentialsValidateFailedError( @@ -167,23 +172,28 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode """ generate custom model entities from credentials """ - support_function_call = False features = [] + function_calling_type = credentials.get('function_calling_type', 'no_call') - if function_calling_type == 'function_call': - features = [ModelFeature.TOOL_CALL] - support_function_call = True - endpoint_url = credentials["endpoint_url"] - # if not endpoint_url.endswith('/'): - # endpoint_url += '/' - # if 'https://api.openai.com/v1/' == endpoint_url: - # features = [ModelFeature.STREAM_TOOL_CALL] + if function_calling_type in ['function_call']: + features.append(ModelFeature.TOOL_CALL) + elif function_calling_type in ['tool_call']: + features.append(ModelFeature.MULTI_TOOL_CALL) + + stream_function_calling = credentials.get('stream_function_calling', 'supported') + if stream_function_calling == 'supported': + features.append(ModelFeature.STREAM_TOOL_CALL) + + vision_support = credentials.get('vision_support', 'not_support') + if vision_support == 'support': + features.append(ModelFeature.VISION) + entity = AIModelEntity( model=model, label=I18nObject(en_US=model), model_type=ModelType.LLM, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - features=features if support_function_call else [], + features=features, model_properties={ ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")), ModelPropertyKey.MODE: credentials.get('mode'), @@ -378,13 +388,49 @@ def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, f delimiter = credentials.get("stream_mode_delimiter", "\n\n") delimiter = codecs.decode(delimiter, "unicode_escape") + tools_calls: list[AssistantPromptMessage.ToolCall] = [] + + def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]): + def get_tool_call(tool_call_id: str): + if not tool_call_id: + return tools_calls[-1] + + tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None) + if tool_call is None: + tool_call = AssistantPromptMessage.ToolCall( + id=tool_call_id, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name="", + arguments="" + ) + ) + tools_calls.append(tool_call) + + return tool_call + + for new_tool_call in new_tool_calls: + # 
get tool call + tool_call = get_tool_call(new_tool_call.function.name) + # update tool call + if new_tool_call.id: + tool_call.id = new_tool_call.id + if new_tool_call.type: + tool_call.type = new_tool_call.type + if new_tool_call.function.name: + tool_call.function.name = new_tool_call.function.name + if new_tool_call.function.arguments: + tool_call.function.arguments += new_tool_call.function.arguments + + finish_reason = 'Unknown' + for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter): if chunk: # ignore sse comments if chunk.startswith(':'): continue decoded_chunk = chunk.strip().lstrip('data: ').lstrip() - chunk_json = None + try: chunk_json = json.loads(decoded_chunk) # stream ended @@ -405,24 +451,35 @@ def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, f if 'delta' in choice: delta = choice['delta'] delta_content = delta.get('content') - if delta_content is None or delta_content == '': - continue - assistant_message_tool_calls = delta.get('tool_calls', None) + assistant_message_tool_calls = None + + if 'tool_calls' in delta and credentials.get('function_calling_type', 'no_call') == 'tool_call': + assistant_message_tool_calls = delta.get('tool_calls', None) + elif 'function_call' in delta and credentials.get('function_calling_type', 'no_call') == 'function_call': + assistant_message_tool_calls = [{ + 'id': 'tool_call_id', + 'type': 'function', + 'function': delta.get('function_call', {}) + }] + # assistant_message_function_call = delta.delta.function_call # extract tool calls from response if assistant_message_tool_calls: tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) - # function_call = self._extract_response_function_call(assistant_message_function_call) - # tool_calls = [function_call] if function_call else [] + increase_tool_call(tool_calls) + + if delta_content is None or delta_content == '': + continue # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( content=delta_content, - tool_calls=tool_calls if assistant_message_tool_calls else [] ) + # reset tool calls + tool_calls = [] full_assistant_content += delta_content elif 'text' in choice: choice_text = choice.get('text', '') @@ -435,25 +492,36 @@ def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, f else: continue - # check payload indicator for completion - if finish_reason is not None: - yield create_final_llm_result_chunk( + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( index=chunk_index, message=assistant_prompt_message, - finish_reason=finish_reason - ) - else: - yield LLMResultChunk( - model=model, - prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=chunk_index, - message=assistant_prompt_message, - ) ) + ) chunk_index += 1 + if tools_calls: + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=chunk_index, + message=AssistantPromptMessage( + tool_calls=tools_calls, + content="" + ), + ) + ) + + yield create_final_llm_result_chunk( + index=chunk_index, + message=AssistantPromptMessage(content=""), + finish_reason=finish_reason + ) + def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]) -> LLMResult: @@ -573,7 +641,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: return message_dict - def _num_tokens_from_string(self, model: str, 
text: str, + def _num_tokens_from_string(self, model: str, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None) -> int: """ Approximate num tokens for model with gpt2 tokenizer. @@ -583,7 +651,16 @@ def _num_tokens_from_string(self, model: str, text: str, :param tools: tools for tool calling :return: number of tokens """ - num_tokens = self._get_num_tokens_by_gpt2(text) + if isinstance(text, str): + full_text = text + else: + full_text = '' + for message_content in text: + if message_content.type == PromptMessageContentType.TEXT: + message_content = cast(PromptMessageContent, message_content) + full_text += message_content.data + + num_tokens = self._get_num_tokens_by_gpt2(full_text) if tools: num_tokens += self._num_tokens_for_tools(tools) @@ -701,13 +778,13 @@ def _extract_response_tool_calls(self, if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call["function"]["name"], - arguments=response_tool_call["function"]["arguments"] + name=response_tool_call.get("function", {}).get("name", ""), + arguments=response_tool_call.get("function", {}).get("arguments", "") ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call["id"], - type=response_tool_call["type"], + id=response_tool_call.get("id", ""), + type=response_tool_call.get("type", ""), function=function ) tool_calls.append(tool_call) @@ -725,12 +802,12 @@ def _extract_response_function_call(self, response_function_call) \ tool_call = None if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call['name'], - arguments=response_function_call['arguments'] + name=response_function_call.get('name', ''), + arguments=response_function_call.get('arguments', '') ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call['name'], + id=response_function_call.get('id', ''), type="function", function=function ) diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml index be99f7684ce3a..69bed9603902a 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml +++ b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml @@ -86,14 +86,51 @@ model_credential_schema: default: no_call options: - value: function_call + label: + en_US: Function Call + zh_Hans: Function Call + - value: tool_call + label: + en_US: Tool Call + zh_Hans: Tool Call + - value: no_call + label: + en_US: Not Support + zh_Hans: 不支持 + - variable: stream_function_calling + show_on: + - variable: __model_type + value: llm + label: + en_US: Stream function calling + type: select + required: false + default: not_supported + options: + - value: supported label: en_US: Support zh_Hans: 支持 -# - value: tool_call -# label: -# en_US: Tool Call -# zh_Hans: Tool Call - - value: no_call + - value: not_supported + label: + en_US: Not Support + zh_Hans: 不支持 + - variable: vision_support + show_on: + - variable: __model_type + value: llm + label: + zh_Hans: Vision 支持 + en_US: Vision Support + type: select + required: false + default: no_support + options: + - value: support + label: + en_US: Support + zh_Hans: 支持 + - value: no_support label: en_US: Not Support zh_Hans: 不支持 diff --git a/api/core/model_runtime/model_providers/tongyi/llm/_client.py 
b/api/core/model_runtime/model_providers/tongyi/llm/_client.py deleted file mode 100644 index cfe33558e1958..0000000000000 --- a/api/core/model_runtime/model_providers/tongyi/llm/_client.py +++ /dev/null @@ -1,82 +0,0 @@ -from typing import Any, Optional - -from langchain.callbacks.manager import CallbackManagerForLLMRun -from langchain.llms import Tongyi -from langchain.llms.tongyi import generate_with_retry, stream_generate_with_retry -from langchain.schema import Generation, LLMResult - - -class EnhanceTongyi(Tongyi): - @property - def _default_params(self) -> dict[str, Any]: - """Get the default parameters for calling OpenAI API.""" - normal_params = { - "top_p": self.top_p, - "api_key": self.dashscope_api_key - } - - return {**normal_params, **self.model_kwargs} - - def _generate( - self, - prompts: list[str], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> LLMResult: - generations = [] - params: dict[str, Any] = { - **{"model": self.model_name}, - **self._default_params, - **kwargs, - } - if self.streaming: - if len(prompts) > 1: - raise ValueError("Cannot stream results with multiple prompts.") - params["stream"] = True - text = '' - for stream_resp in stream_generate_with_retry( - self, prompt=prompts[0], **params - ): - if not generations: - current_text = stream_resp["output"]["text"] - else: - current_text = stream_resp["output"]["text"][len(text):] - - text = stream_resp["output"]["text"] - - generations.append( - [ - Generation( - text=current_text, - generation_info=dict( - finish_reason=stream_resp["output"]["finish_reason"], - ), - ) - ] - ) - - if run_manager: - run_manager.on_llm_new_token( - current_text, - verbose=self.verbose, - logprobs=None, - ) - else: - for prompt in prompts: - completion = generate_with_retry( - self, - prompt=prompt, - **params, - ) - generations.append( - [ - Generation( - text=completion["output"]["text"], - generation_info=dict( - finish_reason=completion["output"]["finish_reason"], - ), - ) - ] - ) - return LLMResult(generations=generations) diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index 405f93498ef9f..3d0a80144c6dd 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -1,8 +1,13 @@ +import base64 +import os +import tempfile +import uuid from collections.abc import Generator -from typing import Optional, Union +from http import HTTPStatus +from typing import Optional, Union, cast -from dashscope import get_tokenizer -from dashscope.api_entities.dashscope_response import DashScopeAPIResponse +from dashscope import Generation, MultiModalConversation, get_tokenizer +from dashscope.api_entities.dashscope_response import GenerationResponse from dashscope.common.error import ( AuthenticationError, InvalidParameter, @@ -11,17 +16,21 @@ UnsupportedHTTPMethod, UnsupportedModel, ) -from langchain.llms.tongyi import generate_with_retry, stream_generate_with_retry from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, + ImagePromptMessageContent, PromptMessage, + PromptMessageContentType, PromptMessageTool, SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, UserPromptMessage, ) +from 
core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, @@ -33,10 +42,9 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from ._client import EnhanceTongyi - class TongyiLargeLanguageModel(LargeLanguageModel): + tokenizers = {} def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, @@ -57,13 +65,13 @@ def _invoke(self, model: str, credentials: dict, :return: full response or stream response chunk generator result """ # invoke model - return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) - - def _code_block_mode_wrapper(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, + return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + + def _code_block_mode_wrapper(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, user: str | None = None, callbacks: list[Callback] = None) \ - -> LLMResult | Generator: + -> LLMResult | Generator: """ Wrapper for code block mode """ @@ -88,7 +96,7 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, stream=stream, user=user ) - + model_parameters.pop("response_format") stop = stop or [] stop.extend(["\n```", "```\n"]) @@ -99,13 +107,13 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, # override the system message prompt_messages[0] = SystemPromptMessage( content=block_prompts - .replace("{{instructions}}", prompt_messages[0].content) + .replace("{{instructions}}", prompt_messages[0].content) ) else: # insert the system message prompt_messages.insert(0, SystemPromptMessage( content=block_prompts - .replace("{{instructions}}", f"Please output a valid {code_block} object.") + .replace("{{instructions}}", f"Please output a valid {code_block} object.") )) mode = self.get_model_mode(model, credentials) @@ -138,7 +146,7 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages=prompt_messages, input_generator=response ) - + return response def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], @@ -152,7 +160,14 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr :param tools: tools for tool calling :return: """ - tokenizer = get_tokenizer(model) + if model in ['qwen-turbo-chat', 'qwen-plus-chat']: + model = model.replace('-chat', '') + + if model in self.tokenizers: + tokenizer = self.tokenizers[model] + else: + tokenizer = get_tokenizer(model) + self.tokenizers[model] = tokenizer # convert string to token ids tokens = tokenizer.encode(self._convert_messages_to_prompt(prompt_messages)) @@ -184,6 +199,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ @@ -192,24 +208,27 @@ def _generate(self, 
model: str, credentials: dict, :param model: model name :param credentials: credentials :param prompt_messages: prompt messages + :param tools: tools for tool calling :param model_parameters: model parameters :param stop: stop words :param stream: is stream response :param user: unique user id :return: full response or stream response chunk generator result """ - extra_model_kwargs = {} - if stop: - extra_model_kwargs['stop'] = stop - # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) - client = EnhanceTongyi( - model_name=model, - streaming=stream, - dashscope_api_key=credentials_kwargs['api_key'], - ) + mode = self.get_model_mode(model, credentials) + + if model in ['qwen-turbo-chat', 'qwen-plus-chat']: + model = model.replace('-chat', '') + + extra_model_kwargs = {} + if tools: + extra_model_kwargs['tools'] = self._convert_tools(tools) + + if stop: + extra_model_kwargs['stop'] = stop params = { 'model': model, @@ -218,30 +237,27 @@ def _generate(self, model: str, credentials: dict, **extra_model_kwargs, } - mode = self.get_model_mode(model, credentials) + model_schema = self.get_model_schema(model, credentials) + if ModelFeature.VISION in (model_schema.features or []): + params['messages'] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages, rich_content=True) - if mode == LLMMode.CHAT: - params['messages'] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages) + response = MultiModalConversation.call(**params, stream=stream) else: - params['prompt'] = self._convert_messages_to_prompt(prompt_messages) + if mode == LLMMode.CHAT: + params['messages'] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages) + else: + params['prompt'] = prompt_messages[0].content.rstrip() - if stream: - responses = stream_generate_with_retry( - client, - stream=True, - incremental_output=True, - **params - ) + response = Generation.call(**params, + result_format='message', + stream=stream) - return self._handle_generate_stream_response(model, credentials, responses, prompt_messages) + if stream: + return self._handle_generate_stream_response(model, credentials, response, prompt_messages) - response = generate_with_retry( - client, - **params, - ) return self._handle_generate_response(model, credentials, response, prompt_messages) - - def _handle_generate_response(self, model: str, credentials: dict, response: DashScopeAPIResponse, + + def _handle_generate_response(self, model: str, credentials: dict, response: GenerationResponse, prompt_messages: list[PromptMessage]) -> LLMResult: """ Handle llm response @@ -254,7 +270,7 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Das """ # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=response.output.text + content=response.output.choices[0].message.content, ) # transform usage @@ -270,32 +286,65 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Das return result - def _handle_generate_stream_response(self, model: str, credentials: dict, responses: Generator, + def _handle_generate_stream_response(self, model: str, credentials: dict, + responses: Generator[GenerationResponse, None, None], prompt_messages: list[PromptMessage]) -> Generator: """ Handle llm stream response :param model: model name :param credentials: credentials - :param response: response + :param responses: response :param prompt_messages: prompt messages :return: llm response chunk 
generator result """ + full_text = '' + tool_calls = [] for index, response in enumerate(responses): - resp_finish_reason = response.output.finish_reason - resp_content = response.output.text - usage = response.usage + if response.status_code != 200 and response.status_code != HTTPStatus.OK: + raise ServiceUnavailableError( + f"Failed to invoke model {model}, status code: {response.status_code}, " + f"message: {response.message}" + ) - if resp_finish_reason is None and (resp_content is None or resp_content == ''): - continue + resp_finish_reason = response.output.choices[0].finish_reason - # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=resp_content if resp_content else '', - ) + if resp_finish_reason is not None and resp_finish_reason != 'null': + resp_content = response.output.choices[0].message.content + + assistant_prompt_message = AssistantPromptMessage( + content='', + ) + + if 'tool_calls' in response.output.choices[0].message: + tool_calls = response.output.choices[0].message['tool_calls'] + elif resp_content: + # special for qwen-vl + if isinstance(resp_content, list): + resp_content = resp_content[0]['text'] + + # transform assistant message to prompt message + assistant_prompt_message.content = resp_content.replace(full_text, '', 1) + + full_text = resp_content + + if tool_calls: + message_tool_calls = [] + for tool_call_obj in tool_calls: + message_tool_call = AssistantPromptMessage.ToolCall( + id=tool_call_obj['function']['name'], + type='function', + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=tool_call_obj['function']['name'], + arguments=tool_call_obj['function']['arguments'] + ) + ) + message_tool_calls.append(message_tool_call) + + assistant_prompt_message.tool_calls = message_tool_calls - if resp_finish_reason is not None: # transform usage + usage = response.usage usage = self._calc_response_usage(model, credentials, usage.input_tokens, usage.output_tokens) yield LLMResultChunk( @@ -309,6 +358,23 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon ) ) else: + resp_content = response.output.choices[0].message.content + if not resp_content: + if 'tool_calls' in response.output.choices[0].message: + tool_calls = response.output.choices[0].message['tool_calls'] + continue + + # special for qwen-vl + if isinstance(resp_content, list): + resp_content = resp_content[0]['text'] + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage( + content=resp_content.replace(full_text, '', 1), + ) + + full_text = resp_content + yield LLMResultChunk( model=model, prompt_messages=prompt_messages, @@ -343,11 +409,20 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: content = message.content if isinstance(message, UserPromptMessage): - message_text = f"{human_prompt} {content}" + if isinstance(content, str): + message_text = f"{human_prompt} {content}" + else: + message_text = "" + for sub_message in content: + if sub_message.type == PromptMessageContentType.TEXT: + message_text = f"{human_prompt} {sub_message.data}" + break elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" elif isinstance(message, SystemPromptMessage): message_text = content + elif isinstance(message, ToolPromptMessage): + message_text = content else: raise ValueError(f"Got unknown type {message}") @@ -370,7 +445,8 @@ def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: # trim 
off the trailing ' ' that might come from the "Assistant: " return text.rstrip() - def _convert_prompt_messages_to_tongyi_messages(self, prompt_messages: list[PromptMessage]) -> list[dict]: + def _convert_prompt_messages_to_tongyi_messages(self, prompt_messages: list[PromptMessage], + rich_content: bool = False) -> list[dict]: """ Convert prompt messages to tongyi messages @@ -382,23 +458,118 @@ def _convert_prompt_messages_to_tongyi_messages(self, prompt_messages: list[Prom if isinstance(prompt_message, SystemPromptMessage): tongyi_messages.append({ 'role': 'system', - 'content': prompt_message.content, + 'content': prompt_message.content if not rich_content else [{"text": prompt_message.content}], }) elif isinstance(prompt_message, UserPromptMessage): - tongyi_messages.append({ - 'role': 'user', - 'content': prompt_message.content, - }) + if isinstance(prompt_message.content, str): + tongyi_messages.append({ + 'role': 'user', + 'content': prompt_message.content if not rich_content else [{"text": prompt_message.content}], + }) + else: + sub_messages = [] + for message_content in prompt_message.content: + if message_content.type == PromptMessageContentType.TEXT: + message_content = cast(TextPromptMessageContent, message_content) + sub_message_dict = { + "text": message_content.data + } + sub_messages.append(sub_message_dict) + elif message_content.type == PromptMessageContentType.IMAGE: + message_content = cast(ImagePromptMessageContent, message_content) + + image_url = message_content.data + if message_content.data.startswith("data:"): + # convert image base64 data to file in /tmp + image_url = self._save_base64_image_to_file(message_content.data) + + sub_message_dict = { + "image": image_url + } + sub_messages.append(sub_message_dict) + + # resort sub_messages to ensure text is always at last + sub_messages = sorted(sub_messages, key=lambda x: 'text' in x) + + tongyi_messages.append({ + 'role': 'user', + 'content': sub_messages + }) elif isinstance(prompt_message, AssistantPromptMessage): + content = prompt_message.content + if not content: + content = ' ' tongyi_messages.append({ 'role': 'assistant', - 'content': prompt_message.content, + 'content': content if not rich_content else [{"text": content}], + }) + elif isinstance(prompt_message, ToolPromptMessage): + tongyi_messages.append({ + "role": "tool", + "content": prompt_message.content, + "name": prompt_message.tool_call_id }) else: raise ValueError(f"Got unknown type {prompt_message}") return tongyi_messages + def _save_base64_image_to_file(self, base64_image: str) -> str: + """ + Save base64 image to file + 'data:{upload_file.mime_type};base64,{encoded_string}' + + :param base64_image: base64 image data + :return: image file path + """ + # get mime type and encoded string + mime_type, encoded_string = base64_image.split(',')[0].split(';')[0].split(':')[1], base64_image.split(',')[1] + + # save image to file + temp_dir = tempfile.gettempdir() + + file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.{mime_type.split('/')[1]}") + + with open(file_path, "wb") as image_file: + image_file.write(base64.b64decode(encoded_string)) + + return f"file://{file_path}" + + def _convert_tools(self, tools: list[PromptMessageTool]) -> list[dict]: + """ + Convert tools + """ + tool_definitions = [] + for tool in tools: + properties = tool.parameters['properties'] + required_properties = tool.parameters['required'] + + properties_definitions = {} + for p_key, p_val in properties.items(): + desc = p_val['description'] + if 'enum' in p_val: + desc += 
(f"; Only accepts one of the following predefined options: " + f"[{', '.join(p_val['enum'])}]") + + properties_definitions[p_key] = { + 'description': desc, + 'type': p_val['type'], + } + + tool_definition = { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": properties_definitions, + "required": required_properties + } + } + + tool_definitions.append(tool_definition) + + return tool_definitions + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-0403.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-0403.yaml new file mode 100644 index 0000000000000..865c0c8138688 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-0403.yaml @@ -0,0 +1,81 @@ +model: qwen-max-0403 +label: + en_US: qwen-max-0403 +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 2000 + min: 1 + max: 2000 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. 
+ - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. + - name: enable_search + type: boolean + default: false + help: + zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 + en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. + - name: response_format + use_template: response_format +pricing: + input: '0.12' + output: '0.12' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-1201.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-1201.yaml index 691347e7016a8..533d99aa55dff 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-1201.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-1201.yaml @@ -2,6 +2,10 @@ model: qwen-max-1201 label: en_US: qwen-max-1201 model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call model_properties: mode: chat context_size: 8192 @@ -9,7 +13,7 @@ parameter_rules: - name: temperature use_template: temperature type: float - default: 0.85 + default: 0.3 min: 0.0 max: 2.0 help: diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-longcontext.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-longcontext.yaml index 91129d37dd062..dbe3ece3967f5 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-longcontext.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-longcontext.yaml @@ -2,6 +2,10 @@ model: qwen-max-longcontext label: en_US: qwen-max-longcontext model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call model_properties: mode: chat context_size: 32768 @@ -9,7 +13,7 @@ parameter_rules: - name: temperature use_template: temperature type: float - default: 0.85 + default: 0.3 min: 0.0 max: 2.0 help: diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max.yaml index 5d6b69f21f71b..9a0f1afc03038 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max.yaml @@ -2,6 +2,10 @@ model: qwen-max label: en_US: qwen-max model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call model_properties: mode: chat 
context_size: 8192 @@ -9,7 +13,7 @@ parameter_rules: - name: temperature use_template: temperature type: float - default: 0.85 + default: 0.3 min: 0.0 max: 2.0 help: diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-chat.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-chat.yaml new file mode 100644 index 0000000000000..ae3ec0fc040a2 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-chat.yaml @@ -0,0 +1,81 @@ +model: qwen-plus-chat +label: + en_US: qwen-plus-chat +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 1500 + min: 1 + max: 1500 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. 
When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. + - name: enable_search + type: boolean + default: false + help: + zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 + en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. + - name: response_format + use_template: response_format +pricing: + input: '0.02' + output: '0.02' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus.yaml index 7c25e8802b8b5..bfa04792a0c64 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus.yaml @@ -2,6 +2,8 @@ model: qwen-plus label: en_US: qwen-plus model_type: llm +features: + - agent-thought model_properties: mode: completion context_size: 32768 @@ -9,7 +11,7 @@ parameter_rules: - name: temperature use_template: temperature type: float - default: 0.85 + default: 0.3 min: 0.0 max: 2.0 help: diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo-chat.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo-chat.yaml new file mode 100644 index 0000000000000..dc8208fac62b4 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo-chat.yaml @@ -0,0 +1,81 @@ +model: qwen-turbo-chat +label: + en_US: qwen-turbo-chat +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 1500 + min: 1 + max: 1500 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. 
It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. + - name: enable_search + type: boolean + default: false + help: + zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 + en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. 
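+  # Note (assumed Dify pricing convention): the pricing block below is read as price per
+  # `unit` of tokens, i.e. 0.008 RMB per 1K input tokens and per 1K output tokens.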
+ - name: response_format + use_template: response_format +pricing: + input: '0.008' + output: '0.008' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo.yaml index 20b46de6f3e1c..140dc68af8e20 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo.yaml @@ -2,6 +2,8 @@ model: qwen-turbo label: en_US: qwen-turbo model_type: llm +features: + - agent-thought model_properties: mode: completion context_size: 8192 @@ -9,7 +11,7 @@ parameter_rules: - name: temperature use_template: temperature type: float - default: 0.85 + default: 0.3 min: 0.0 max: 2.0 help: diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max.yaml new file mode 100644 index 0000000000000..f917ccaa5d857 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max.yaml @@ -0,0 +1,47 @@ +model: qwen-vl-max +label: + en_US: qwen-vl-max +model_type: llm +features: + - vision + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. 
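+  # With the `vision` feature, user messages are sent as a list of content parts,
+  # e.g. [{"image": "<url or file:// path>"}, {"text": "<prompt>"}]; the converter in
+  # this provider orders text parts after image parts.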
+ - name: response_format + use_template: response_format +pricing: + input: '0.02' + output: '0.02' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus.yaml new file mode 100644 index 0000000000000..e2dd8c4e576a2 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus.yaml @@ -0,0 +1,47 @@ +model: qwen-vl-plus +label: + en_US: qwen-vl-plus +model_type: llm +features: + - vision + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. 
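+  # Images supplied as base64 `data:` URIs are first written to a temporary file and
+  # passed to the model as a `file://` URL (see _save_base64_image_to_file above).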
+ - name: response_format + use_template: response_format +pricing: + input: '0.008' + output: '0.008' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py index a5f3660fb265c..c207ffc1e34bb 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py @@ -37,8 +37,11 @@ def _invoke( :return: embeddings result """ credentials_kwargs = self._to_credential_kwargs(credentials) - dashscope.api_key = credentials_kwargs["dashscope_api_key"] - embeddings, embedding_used_tokens = self.embed_documents(model, texts) + embeddings, embedding_used_tokens = self.embed_documents( + credentials_kwargs=credentials_kwargs, + model=model, + texts=texts + ) return TextEmbeddingResult( embeddings=embeddings, @@ -74,17 +77,19 @@ def validate_credentials(self, model: str, credentials: dict) -> None: try: # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) - dashscope.api_key = credentials_kwargs["dashscope_api_key"] + # call embedding model - self.embed_documents(model=model, texts=["ping"]) + self.embed_documents(credentials_kwargs=credentials_kwargs, model=model, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @staticmethod - def embed_documents(model: str, texts: list[str]) -> tuple[list[list[float]], int]: + def embed_documents(credentials_kwargs: dict, model: str, texts: list[str]) -> tuple[list[list[float]], int]: """Call out to Tongyi's embedding endpoint. Args: + credentials_kwargs: The credentials to use for the call. + model: The model to use for embedding. texts: The list of texts to embed. 
Returns: @@ -93,7 +98,12 @@ def embed_documents(model: str, texts: list[str]) -> tuple[list[list[float]], in embeddings = [] embedding_used_tokens = 0 for text in texts: - response = dashscope.TextEmbedding.call(model=model, input=text, text_type="document") + response = dashscope.TextEmbedding.call( + api_key=credentials_kwargs["dashscope_api_key"], + model=model, + input=text, + text_type="document" + ) data = response.output["embeddings"][0] embeddings.append(data["embedding"]) embedding_used_tokens += response.usage["total_tokens"] diff --git a/api/core/model_runtime/model_providers/tongyi/tts/tts.py b/api/core/model_runtime/model_providers/tongyi/tts/tts.py index 937f469bdfab4..b00f7c7c93778 100644 --- a/api/core/model_runtime/model_providers/tongyi/tts/tts.py +++ b/api/core/model_runtime/model_providers/tongyi/tts/tts.py @@ -118,7 +118,6 @@ def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, c :param content_text: text content to be translated :return: text translated to audio file """ - dashscope.api_key = credentials.get('dashscope_api_key') word_limit = self._get_model_word_limit(model, credentials) audio_type = self._get_model_audio_type(model, credentials) tts_file_id = self._get_file_name(content_text) @@ -127,6 +126,7 @@ def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, c sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) for sentence in sentences: response = dashscope.audio.tts.SpeechSynthesizer.call(model=voice, sample_rate=48000, + api_key=credentials.get('dashscope_api_key'), text=sentence.strip(), format=audio_type, word_timestamp_enabled=True, phoneme_timestamp_enabled=True) @@ -146,8 +146,8 @@ def _process_sentence(sentence: str, credentials: dict, voice: str, audio_type: :param audio_type: audio file type :return: text translated to audio file """ - dashscope.api_key = credentials.get('dashscope_api_key') response = dashscope.audio.tts.SpeechSynthesizer.call(model=voice, sample_rate=48000, + api_key=credentials.get('dashscope_api_key'), text=sentence.strip(), format=audio_type) if isinstance(response.get_audio_data(), bytes): diff --git a/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.yaml b/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.yaml index 50a804743d484..ca2fad33addc6 100644 --- a/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.yaml +++ b/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.yaml @@ -43,7 +43,7 @@ model_credential_schema: placeholder: zh_Hans: 在此输入您的上下文大小 en_US: Enter the context size - default: 2048 + default: '2048' - variable: completion_type label: zh_Hans: 补全类型 @@ -69,16 +69,16 @@ model_credential_schema: en_US: Stream output type: select required: true - default: true + default: 'true' placeholder: zh_Hans: 是否支持流式输出 en_US: Whether to support stream output options: - label: zh_Hans: 是 - en_US: Yes - value: true + en_US: 'Yes' + value: 'true' - label: zh_Hans: 否 - en_US: No - value: false + en_US: 'No' + value: 'false' diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-4k-0205.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-4k-0205.yaml index 7fea3872b15e2..9487342a1d2c5 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-4k-0205.yaml +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-4k-0205.yaml @@ -1,6 +1,6 @@ -model: 
ernie-3.5-8k +model: ernie-3.5-4k-0205 label: - en_US: Ernie-3.5-8K + en_US: Ernie-3.5-4k-0205 model_type: llm features: - agent-thought diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 79967d9004a8d..cd744383375da 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -232,8 +232,8 @@ def _get_completion_model_prompt_messages(self, app_mode: AppMode, ) ), max_token_limit=rest_tokens, - ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', - human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' + human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', + ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' ) # get prompt diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index 46478ec1316eb..0f4cbccff798c 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -24,56 +24,67 @@ def __init__(self, dataset: Dataset): self._config = KeywordTableConfig() def create(self, texts: list[Document], **kwargs) -> BaseKeyword: - keyword_table_handler = JiebaKeywordTableHandler() - keyword_table = self._get_dataset_keyword_table() - for text in texts: - keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) - self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) - keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) + lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + with redis_client.lock(lock_name, timeout=600): + keyword_table_handler = JiebaKeywordTableHandler() + keyword_table = self._get_dataset_keyword_table() + for text in texts: + keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) + self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) + keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) - self._save_dataset_keyword_table(keyword_table) + self._save_dataset_keyword_table(keyword_table) - return self + return self def add_texts(self, texts: list[Document], **kwargs): - keyword_table_handler = JiebaKeywordTableHandler() - - keyword_table = self._get_dataset_keyword_table() - keywords_list = kwargs.get('keywords_list', None) - for i in range(len(texts)): - text = texts[i] - if keywords_list: - keywords = keywords_list[i] - else: - keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) - self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) - keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) - - self._save_dataset_keyword_table(keyword_table) + lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + with redis_client.lock(lock_name, timeout=600): + keyword_table_handler = JiebaKeywordTableHandler() + + keyword_table = self._get_dataset_keyword_table() + keywords_list = kwargs.get('keywords_list', None) + for i in range(len(texts)): + text = texts[i] + if keywords_list: + keywords = keywords_list[i] + if not keywords: + keywords = keyword_table_handler.extract_keywords(text.page_content, 
+ self._config.max_keywords_per_chunk) + else: + keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) + self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) + keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) + + self._save_dataset_keyword_table(keyword_table) def text_exists(self, id: str) -> bool: keyword_table = self._get_dataset_keyword_table() return id in set.union(*keyword_table.values()) def delete_by_ids(self, ids: list[str]) -> None: - keyword_table = self._get_dataset_keyword_table() - keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) + lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + with redis_client.lock(lock_name, timeout=600): + keyword_table = self._get_dataset_keyword_table() + keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) - self._save_dataset_keyword_table(keyword_table) + self._save_dataset_keyword_table(keyword_table) def delete_by_document_id(self, document_id: str): - # get segment ids by document_id - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == self.dataset.id, - DocumentSegment.document_id == document_id - ).all() + lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + with redis_client.lock(lock_name, timeout=600): + # get segment ids by document_id + segments = db.session.query(DocumentSegment).filter( + DocumentSegment.dataset_id == self.dataset.id, + DocumentSegment.document_id == document_id + ).all() - ids = [segment.index_node_id for segment in segments] + ids = [segment.index_node_id for segment in segments] - keyword_table = self._get_dataset_keyword_table() - keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) + keyword_table = self._get_dataset_keyword_table() + keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) - self._save_dataset_keyword_table(keyword_table) + self._save_dataset_keyword_table(keyword_table) def search( self, query: str, @@ -106,13 +117,15 @@ def search( return documents def delete(self) -> None: - dataset_keyword_table = self.dataset.dataset_keyword_table - if dataset_keyword_table: - db.session.delete(dataset_keyword_table) - db.session.commit() - if dataset_keyword_table.data_source_type != 'database': - file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt' - storage.delete(file_key) + lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + with redis_client.lock(lock_name, timeout=600): + dataset_keyword_table = self.dataset.dataset_keyword_table + if dataset_keyword_table: + db.session.delete(dataset_keyword_table) + db.session.commit() + if dataset_keyword_table.data_source_type != 'database': + file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt' + storage.delete(file_key) def _save_dataset_keyword_table(self, keyword_table): keyword_table_dict = { @@ -135,33 +148,31 @@ def _save_dataset_keyword_table(self, keyword_table): storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode('utf-8')) def _get_dataset_keyword_table(self) -> Optional[dict]: - lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) - with redis_client.lock(lock_name, timeout=20): - dataset_keyword_table = self.dataset.dataset_keyword_table - if dataset_keyword_table: - keyword_table_dict = dataset_keyword_table.keyword_table_dict - if keyword_table_dict: - return 
keyword_table_dict['__data__']['table'] - else: - keyword_data_source_type = current_app.config['KEYWORD_DATA_SOURCE_TYPE'] - dataset_keyword_table = DatasetKeywordTable( - dataset_id=self.dataset.id, - keyword_table='', - data_source_type=keyword_data_source_type, - ) - if keyword_data_source_type == 'database': - dataset_keyword_table.keyword_table = json.dumps({ - '__type__': 'keyword_table', - '__data__': { - "index_id": self.dataset.id, - "summary": None, - "table": {} - } - }, cls=SetEncoder) - db.session.add(dataset_keyword_table) - db.session.commit() + dataset_keyword_table = self.dataset.dataset_keyword_table + if dataset_keyword_table: + keyword_table_dict = dataset_keyword_table.keyword_table_dict + if keyword_table_dict: + return keyword_table_dict['__data__']['table'] + else: + keyword_data_source_type = current_app.config['KEYWORD_DATA_SOURCE_TYPE'] + dataset_keyword_table = DatasetKeywordTable( + dataset_id=self.dataset.id, + keyword_table='', + data_source_type=keyword_data_source_type, + ) + if keyword_data_source_type == 'database': + dataset_keyword_table.keyword_table = json.dumps({ + '__type__': 'keyword_table', + '__data__': { + "index_id": self.dataset.id, + "summary": None, + "table": {} + } + }, cls=SetEncoder) + db.session.add(dataset_keyword_table) + db.session.commit() - return {} + return {} def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict: for keyword in keywords: diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index dcb37ccbe605e..29bb467acfe03 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -20,16 +20,17 @@ class MilvusConfig(BaseModel): password: str secure: bool = False batch_size: int = 100 + database: str = "default" @root_validator() def validate_config(cls, values: dict) -> dict: - if not values['host']: + if not values.get('host'): raise ValueError("config MILVUS_HOST is required") - if not values['port']: + if not values.get('port'): raise ValueError("config MILVUS_PORT is required") - if not values['user']: + if not values.get('user'): raise ValueError("config MILVUS_USER is required") - if not values['password']: + if not values.get('password'): raise ValueError("config MILVUS_PASSWORD is required") return values @@ -39,7 +40,8 @@ def to_milvus_params(self): 'port': self.port, 'user': self.user, 'password': self.password, - 'secure': self.secure + 'secure': self.secure, + 'db_name': self.database, } @@ -128,7 +130,8 @@ def delete(self) -> None: uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) else: uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) - connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password) + connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password, + db_name=self._client_config.database) from pymilvus import utility if utility.has_collection(self._collection_name, using=alias): @@ -140,7 +143,8 @@ def text_exists(self, id: str) -> bool: uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) else: uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) - connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password) + connections.connect(alias=alias, 
uri=uri, user=self._client_config.user, password=self._client_config.password, + db_name=self._client_config.database) from pymilvus import utility if not utility.has_collection(self._collection_name, using=alias): @@ -192,7 +196,7 @@ def create_collection( else: uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) connections.connect(alias=alias, uri=uri, user=self._client_config.user, - password=self._client_config.password) + password=self._client_config.password, db_name=self._client_config.database) if not utility.has_collection(self._collection_name, using=alias): from pymilvus import CollectionSchema, DataType, FieldSchema from pymilvus.orm.types import infer_dtype_bydata diff --git a/api/core/rag/datasource/vdb/relyt/__init__.py b/api/core/rag/datasource/vdb/relyt/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py new file mode 100644 index 0000000000000..cfd97218b94d4 --- /dev/null +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -0,0 +1,169 @@ +import logging +from typing import Any + +from pgvecto_rs.sdk import PGVectoRs, Record +from pydantic import BaseModel, root_validator +from sqlalchemy import text as sql_text +from sqlalchemy.orm import Session + +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.models.document import Document +from extensions.ext_redis import redis_client + +logger = logging.getLogger(__name__) + +class RelytConfig(BaseModel): + host: str + port: int + user: str + password: str + database: str + + @root_validator() + def validate_config(cls, values: dict) -> dict: + if not values['host']: + raise ValueError("config RELYT_HOST is required") + if not values['port']: + raise ValueError("config RELYT_PORT is required") + if not values['user']: + raise ValueError("config RELYT_USER is required") + if not values['password']: + raise ValueError("config RELYT_PASSWORD is required") + if not values['database']: + raise ValueError("config RELYT_DATABASE is required") + return values + + +class RelytVector(BaseVector): + + def __init__(self, collection_name: str, config: RelytConfig, dim: int): + super().__init__(collection_name) + self._client_config = config + self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" + self._client = PGVectoRs( + db_url=self._url, + collection_name=self._collection_name, + dimension=dim + ) + self._fields = [] + + def get_type(self) -> str: + return 'relyt' + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + index_params = {} + metadatas = [d.metadata for d in texts] + self.create_collection(len(embeddings[0])) + self.add_texts(texts, embeddings) + + def create_collection(self, dimension: int): + lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + if redis_client.get(collection_exist_cache_key): + return + index_name = f"{self._collection_name}_embedding_index" + with Session(self._client._engine) as session: + drop_statement = sql_text(f"DROP TABLE IF EXISTS collection_{self._collection_name}") + session.execute(drop_statement) + create_statement = sql_text(f""" + CREATE TABLE IF NOT EXISTS collection_{self._collection_name} ( + id UUID PRIMARY KEY, + text TEXT NOT NULL, + meta JSONB NOT NULL, + 
embedding vector({dimension}) NOT NULL + ) using heap; + """) + session.execute(create_statement) + index_statement = sql_text(f""" + CREATE INDEX {index_name} + ON collection_{self._collection_name} USING vectors(embedding vector_l2_ops) + WITH (options = $$ + optimizing.optimizing_threads = 30 + segment.max_growing_segment_size = 2000 + segment.max_sealed_segment_size = 30000000 + [indexing.hnsw] + m=30 + ef_construction=500 + $$); + """) + session.execute(index_statement) + session.commit() + redis_client.set(collection_exist_cache_key, 1, ex=3600) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + records = [Record.from_text(d.page_content, e, d.metadata) for d, e in zip(documents, embeddings)] + pks = [str(r.id) for r in records] + self._client.insert(records) + return pks + + def delete_by_document_id(self, document_id: str): + ids = self.get_ids_by_metadata_field('document_id', document_id) + if ids: + self._client.delete_by_ids(ids) + + def get_ids_by_metadata_field(self, key: str, value: str): + result = None + with Session(self._client._engine) as session: + select_statement = sql_text( + f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'{key}' = '{value}'; " + ) + result = session.execute(select_statement).fetchall() + if result: + return [item[0] for item in result] + else: + return None + + def delete_by_metadata_field(self, key: str, value: str): + + ids = self.get_ids_by_metadata_field(key, value) + if ids: + self._client.delete_by_ids(ids) + + def delete_by_ids(self, doc_ids: list[str]) -> None: + with Session(self._client._engine) as session: + select_statement = sql_text( + f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'doc_id' in ('{doc_ids}'); " + ) + result = session.execute(select_statement).fetchall() + if result: + ids = [item[0] for item in result] + self._client.delete_by_ids(ids) + + def delete(self) -> None: + with Session(self._client._engine) as session: + session.execute(sql_text(f"DROP TABLE IF EXISTS collection_{self._collection_name}")) + session.commit() + + def text_exists(self, id: str) -> bool: + with Session(self._client._engine) as session: + select_statement = sql_text( + f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'doc_id' = '{id}' limit 1; " + ) + result = session.execute(select_statement).fetchall() + return len(result) > 0 + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + from pgvecto_rs.sdk import filters + filter_condition = filters.meta_contains(kwargs.get('filter')) + results = self._client.search( + top_k=int(kwargs.get('top_k')), + embedding=query_vector, + filter=filter_condition + ) + + # Organize results. 
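+        # Each hit returned by pgvecto_rs is a (record, distance) pair; the distance is
+        # surfaced to callers via metadata['score'], and only hits whose distance exceeds
+        # the optional score_threshold (default 0.0) are kept.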
+ docs = [] + for record, dis in results: + metadata = record.meta + metadata['score'] = dis + score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 + if dis > score_threshold: + doc = Document(page_content=record.text, + metadata=metadata) + docs.append(doc) + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + # milvus/zilliz/relyt doesn't support bm25 search + return [] diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 71fc07967c7d5..b6ec7a11fb2c0 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -110,8 +110,34 @@ def _init_vector(self) -> BaseVector: user=config.get('MILVUS_USER'), password=config.get('MILVUS_PASSWORD'), secure=config.get('MILVUS_SECURE'), + database=config.get('MILVUS_DATABASE'), ) ) + elif vector_type == "relyt": + from core.rag.datasource.vdb.relyt.relyt_vector import RelytConfig, RelytVector + if self._dataset.index_struct_dict: + class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] + collection_name = class_prefix + else: + dataset_id = self._dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + index_struct_dict = { + "type": 'relyt', + "vector_store": {"class_prefix": collection_name} + } + self._dataset.index_struct = json.dumps(index_struct_dict) + dim = len(self._embeddings.embed_query("hello relyt")) + return RelytVector( + collection_name=collection_name, + config=RelytConfig( + host=config.get('RELYT_HOST'), + port=config.get('RELYT_PORT'), + user=config.get('RELYT_USER'), + password=config.get('RELYT_PASSWORD'), + database=config.get('RELYT_DATABASE'), + ), + dim=dim + ) else: raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") diff --git a/api/core/rag/extractor/blod/blod.py b/api/core/rag/extractor/blod/blod.py index 368946b5e41d1..8d423e1b3f623 100644 --- a/api/core/rag/extractor/blod/blod.py +++ b/api/core/rag/extractor/blod/blod.py @@ -159,7 +159,7 @@ class BlobLoader(ABC): def yield_blobs( self, ) -> Iterable[Blob]: - """A lazy loader for raw data represented by LangChain's Blob object. + """A lazy loader for raw data represented by Blob object. 
Returns: A generator over blobs diff --git a/api/core/rag/extractor/csv_extractor.py b/api/core/rag/extractor/csv_extractor.py index 059bee5f6caa1..09a1cddd1eb46 100644 --- a/api/core/rag/extractor/csv_extractor.py +++ b/api/core/rag/extractor/csv_extractor.py @@ -34,6 +34,7 @@ def __init__( def extract(self) -> list[Document]: """Load data into document objects.""" + docs = [] try: with open(self._file_path, newline="", encoding=self._encoding) as csvfile: docs = self._read_from_file(csvfile) diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index 0a964bdb013ad..2b0066448ee1c 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -2,6 +2,7 @@ from typing import Optional import pandas as pd +import xlrd from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -27,10 +28,37 @@ def __init__( self._autodetect_encoding = autodetect_encoding def extract(self) -> list[Document]: + """ parse excel file""" + if self._file_path.endswith('.xls'): + return self._extract4xls() + elif self._file_path.endswith('.xlsx'): + return self._extract4xlsx() + + def _extract4xls(self) -> list[Document]: + wb = xlrd.open_workbook(filename=self._file_path) + documents = [] + # loop over all sheets + for sheet in wb.sheets(): + for row_index, row in enumerate(sheet.get_rows(), start=1): + row_header = None + if self.is_blank_row(row): + continue + if row_header is None: + row_header = row + continue + item_arr = [] + for index, cell in enumerate(row): + txt_value = str(cell.value) + item_arr.append(f'{row_header[index].value}:{txt_value}') + item_str = "\n".join(item_arr) + document = Document(page_content=item_str, metadata={'source': self._file_path}) + documents.append(document) + return documents + + def _extract4xlsx(self) -> list[Document]: """Load from file path using Pandas.""" data = [] - - # 使用 Pandas 读取 Excel 文件的每个工作表 + # Read each worksheet of an Excel file using Pandas xls = pd.ExcelFile(self._file_path) for sheet_name in xls.sheet_names: df = pd.read_excel(xls, sheet_name=sheet_name) @@ -43,5 +71,18 @@ def extract(self) -> list[Document]: item = ';'.join(f'{k}:{v}' for k, v in row.items() if pd.notna(v)) document = Document(page_content=item, metadata={'source': self._file_path}) data.append(document) - return data + + @staticmethod + def is_blank_row(row): + """ + + Determine whether the specified line is a blank line. + :param row: row object。 + :return: Returns True if the row is blank, False otherwise. 
+ """ + # Iterates through the cells and returns False if a non-empty cell is found + for cell in row: + if cell.value is not None and cell.value != '': + return False + return True diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 0de706533523e..1136e11f76524 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -16,6 +16,7 @@ from core.rag.extractor.text_extractor import TextExtractor from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor +from core.rag.extractor.unstructured.unstructured_epub_extractor import UnstructuredEpubExtractor from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor @@ -83,7 +84,7 @@ def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False, etl_type = current_app.config['ETL_TYPE'] unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL'] if etl_type == 'Unstructured': - if file_extension == '.xlsx': + if file_extension == '.xlsx' or file_extension == '.xls': extractor = ExcelExtractor(file_path) elif file_extension == '.pdf': extractor = PdfExtractor(file_path) @@ -106,12 +107,14 @@ def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False, extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url) elif file_extension == '.xml': extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url) + elif file_extension == 'epub': + extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url) else: # txt extractor = UnstructuredTextExtractor(file_path, unstructured_api_url) if is_automatic \ else TextExtractor(file_path, autodetect_encoding=True) else: - if file_extension == '.xlsx': + if file_extension == '.xlsx' or file_extension == '.xls': extractor = ExcelExtractor(file_path) elif file_extension == '.pdf': extractor = PdfExtractor(file_path) @@ -123,6 +126,8 @@ def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False, extractor = WordExtractor(file_path) elif file_extension == '.csv': extractor = CSVExtractor(file_path, autodetect_encoding=True) + elif file_extension == 'epub': + extractor = UnstructuredEpubExtractor(file_path) else: # txt extractor = TextExtractor(file_path, autodetect_encoding=True) diff --git a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py new file mode 100644 index 0000000000000..44cf958ea2b63 --- /dev/null +++ b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py @@ -0,0 +1,37 @@ +import logging + +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document + +logger = logging.getLogger(__name__) + + +class UnstructuredEpubExtractor(BaseExtractor): + """Load epub files. + + + Args: + file_path: Path to the file to load. 
+ """ + + def __init__( + self, + file_path: str, + api_url: str = None, + ): + """Initialize with file path.""" + self._file_path = file_path + self._api_url = api_url + + def extract(self) -> list[Document]: + from unstructured.partition.epub import partition_epub + + elements = partition_epub(filename=self._file_path, xml_keep_tags=True) + from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) + documents = [] + for chunk in chunks: + text = chunk.text.strip() + documents.append(Document(page_content=text)) + + return documents diff --git a/api/core/rag/retrieval/agent/fake_llm.py b/api/core/rag/retrieval/agent/fake_llm.py deleted file mode 100644 index ab5152b38d8c8..0000000000000 --- a/api/core/rag/retrieval/agent/fake_llm.py +++ /dev/null @@ -1,59 +0,0 @@ -import time -from collections.abc import Mapping -from typing import Any, Optional - -from langchain.callbacks.manager import CallbackManagerForLLMRun -from langchain.chat_models.base import SimpleChatModel -from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult - - -class FakeLLM(SimpleChatModel): - """Fake ChatModel for testing purposes.""" - - streaming: bool = False - """Whether to stream the results or not.""" - response: str - - @property - def _llm_type(self) -> str: - return "fake-chat-model" - - def _call( - self, - messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - """First try to lookup in queries, else return 'foo' or 'bar'.""" - return self.response - - @property - def _identifying_params(self) -> Mapping[str, Any]: - return {"response": self.response} - - def get_num_tokens(self, text: str) -> int: - return 0 - - def _generate( - self, - messages: list[BaseMessage], - stop: Optional[list[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs) - if self.streaming: - for token in output_str: - if run_manager: - run_manager.on_llm_new_token(token) - time.sleep(0.01) - - message = AIMessage(content=output_str) - generation = ChatGeneration(message=message) - llm_output = {"token_usage": { - 'prompt_tokens': 0, - 'completion_tokens': 0, - 'total_tokens': 0, - }} - return ChatResult(generations=[generation], llm_output=llm_output) diff --git a/api/core/rag/retrieval/agent/llm_chain.py b/api/core/rag/retrieval/agent/llm_chain.py deleted file mode 100644 index f2c5d4ca33042..0000000000000 --- a/api/core/rag/retrieval/agent/llm_chain.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Any, Optional - -from langchain import LLMChain as LCLLMChain -from langchain.callbacks.manager import CallbackManagerForChainRun -from langchain.schema import Generation, LLMResult -from langchain.schema.language_model import BaseLanguageModel - -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.entities.message_entities import lc_messages_to_prompt_messages -from core.model_manager import ModelInstance -from core.rag.retrieval.agent.fake_llm import FakeLLM - - -class LLMChain(LCLLMChain): - model_config: ModelConfigWithCredentialsEntity - """The language model instance to use.""" - llm: BaseLanguageModel = FakeLLM(response="") - parameters: dict[str, Any] = {} - - def generate( - self, - input_list: list[dict[str, Any]], - run_manager: 
Optional[CallbackManagerForChainRun] = None, - ) -> LLMResult: - """Generate LLM result from inputs.""" - prompts, stop = self.prep_prompts(input_list, run_manager=run_manager) - messages = prompts[0].to_messages() - prompt_messages = lc_messages_to_prompt_messages(messages) - - model_instance = ModelInstance( - provider_model_bundle=self.model_config.provider_model_bundle, - model=self.model_config.model, - ) - - result = model_instance.invoke_llm( - prompt_messages=prompt_messages, - stream=False, - stop=stop, - model_parameters=self.parameters - ) - - generations = [ - [Generation(text=result.message.content)] - ] - - return LLMResult(generations=generations) diff --git a/api/core/rag/retrieval/agent/multi_dataset_router_agent.py b/api/core/rag/retrieval/agent/multi_dataset_router_agent.py deleted file mode 100644 index be24731d46a39..0000000000000 --- a/api/core/rag/retrieval/agent/multi_dataset_router_agent.py +++ /dev/null @@ -1,179 +0,0 @@ -from collections.abc import Sequence -from typing import Any, Optional, Union - -from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent -from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message -from langchain.callbacks.base import BaseCallbackManager -from langchain.callbacks.manager import Callbacks -from langchain.prompts.chat import BaseMessagePromptTemplate -from langchain.schema import AgentAction, AgentFinish, AIMessage, SystemMessage -from langchain.tools import BaseTool -from pydantic import root_validator - -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.entities.message_entities import lc_messages_to_prompt_messages -from core.model_manager import ModelInstance -from core.model_runtime.entities.message_entities import PromptMessageTool -from core.rag.retrieval.agent.fake_llm import FakeLLM - - -class MultiDatasetRouterAgent(OpenAIFunctionsAgent): - """ - An Multi Dataset Retrieve Agent driven by Router. - """ - model_config: ModelConfigWithCredentialsEntity - - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - - @root_validator - def validate_llm(cls, values: dict) -> dict: - return values - - def should_use_agent(self, query: str): - """ - return should use agent - - :param query: - :return: - """ - return True - - def plan( - self, - intermediate_steps: list[tuple[AgentAction, str]], - callbacks: Callbacks = None, - **kwargs: Any, - ) -> Union[AgentAction, AgentFinish]: - """Given input, decided what to do. - - Args: - intermediate_steps: Steps the LLM has taken to date, along with observations - **kwargs: User inputs. - - Returns: - Action specifying what tool to use. 
- """ - if len(self.tools) == 0: - return AgentFinish(return_values={"output": ''}, log='') - elif len(self.tools) == 1: - tool = next(iter(self.tools)) - rst = tool.run(tool_input={'query': kwargs['input']}) - # output = '' - # rst_json = json.loads(rst) - # for item in rst_json: - # output += f'{item["content"]}\n' - return AgentFinish(return_values={"output": rst}, log=rst) - - if intermediate_steps: - _, observation = intermediate_steps[-1] - return AgentFinish(return_values={"output": observation}, log=observation) - - try: - agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs) - if isinstance(agent_decision, AgentAction): - tool_inputs = agent_decision.tool_input - if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs: - tool_inputs['query'] = kwargs['input'] - agent_decision.tool_input = tool_inputs - else: - agent_decision.return_values['output'] = '' - return agent_decision - except Exception as e: - raise e - - def real_plan( - self, - intermediate_steps: list[tuple[AgentAction, str]], - callbacks: Callbacks = None, - **kwargs: Any, - ) -> Union[AgentAction, AgentFinish]: - """Given input, decided what to do. - - Args: - intermediate_steps: Steps the LLM has taken to date, along with observations - **kwargs: User inputs. - - Returns: - Action specifying what tool to use. - """ - agent_scratchpad = _format_intermediate_steps(intermediate_steps) - selected_inputs = { - k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" - } - full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad) - prompt = self.prompt.format_prompt(**full_inputs) - messages = prompt.to_messages() - prompt_messages = lc_messages_to_prompt_messages(messages) - - model_instance = ModelInstance( - provider_model_bundle=self.model_config.provider_model_bundle, - model=self.model_config.model, - ) - - tools = [] - for function in self.functions: - tool = PromptMessageTool( - **function - ) - - tools.append(tool) - - result = model_instance.invoke_llm( - prompt_messages=prompt_messages, - tools=tools, - stream=False, - model_parameters={ - 'temperature': 0.2, - 'top_p': 0.3, - 'max_tokens': 1500 - } - ) - - ai_message = AIMessage( - content=result.message.content or "", - additional_kwargs={ - 'function_call': { - 'id': result.message.tool_calls[0].id, - **result.message.tool_calls[0].function.dict() - } if result.message.tool_calls else None - } - ) - - agent_decision = _parse_ai_message(ai_message) - return agent_decision - - async def aplan( - self, - intermediate_steps: list[tuple[AgentAction, str]], - callbacks: Callbacks = None, - **kwargs: Any, - ) -> Union[AgentAction, AgentFinish]: - raise NotImplementedError() - - @classmethod - def from_llm_and_tools( - cls, - model_config: ModelConfigWithCredentialsEntity, - tools: Sequence[BaseTool], - callback_manager: Optional[BaseCallbackManager] = None, - extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None, - system_message: Optional[SystemMessage] = SystemMessage( - content="You are a helpful AI assistant." 
- ), - **kwargs: Any, - ) -> BaseSingleActionAgent: - prompt = cls.create_prompt( - extra_prompt_messages=extra_prompt_messages, - system_message=system_message, - ) - return cls( - model_config=model_config, - llm=FakeLLM(response=''), - prompt=prompt, - tools=tools, - callback_manager=callback_manager, - **kwargs, - ) diff --git a/api/core/rag/retrieval/agent/output_parser/structured_chat.py b/api/core/rag/retrieval/agent/output_parser/structured_chat.py deleted file mode 100644 index c2d748d8f6e31..0000000000000 --- a/api/core/rag/retrieval/agent/output_parser/structured_chat.py +++ /dev/null @@ -1,29 +0,0 @@ -import json -import re -from typing import Union - -from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser as LCStructuredChatOutputParser -from langchain.agents.structured_chat.output_parser import logger -from langchain.schema import AgentAction, AgentFinish, OutputParserException - - -class StructuredChatOutputParser(LCStructuredChatOutputParser): - def parse(self, text: str) -> Union[AgentAction, AgentFinish]: - try: - action_match = re.search(r"```(\w*)\n?({.*?)```", text, re.DOTALL) - if action_match is not None: - response = json.loads(action_match.group(2).strip(), strict=False) - if isinstance(response, list): - # gpt turbo frequently ignores the directive to emit a single action - logger.warning("Got multiple action responses: %s", response) - response = response[0] - if response["action"] == "Final Answer": - return AgentFinish({"output": response["action_input"]}, text) - else: - return AgentAction( - response["action"], response.get("action_input", {}), text - ) - else: - return AgentFinish({"output": text}, text) - except Exception as e: - raise OutputParserException(f"Could not parse LLM output: {text}") diff --git a/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py b/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py deleted file mode 100644 index 7035ec8e2f783..0000000000000 --- a/api/core/rag/retrieval/agent/structed_multi_dataset_router_agent.py +++ /dev/null @@ -1,259 +0,0 @@ -import re -from collections.abc import Sequence -from typing import Any, Optional, Union, cast - -from langchain import BasePromptTemplate, PromptTemplate -from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent -from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE -from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX -from langchain.callbacks.base import BaseCallbackManager -from langchain.callbacks.manager import Callbacks -from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate -from langchain.schema import AgentAction, AgentFinish, OutputParserException -from langchain.tools import BaseTool - -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.rag.retrieval.agent.llm_chain import LLMChain - -FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). -The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. 
-Valid "action" values: "Final Answer" or {tool_names} - -Provide only ONE action per $JSON_BLOB, as shown: - -``` -{{{{ - "action": $TOOL_NAME, - "action_input": $INPUT -}}}} -``` - -Follow this format: - -Question: input question to answer -Thought: consider previous and subsequent steps -Action: -``` -$JSON_BLOB -``` -Observation: action result -... (repeat Thought/Action/Observation N times) -Thought: I know what to respond -Action: -``` -{{{{ - "action": "Final Answer", - "action_input": "Final response to human" -}}}} -```""" - - -class StructuredMultiDatasetRouterAgent(StructuredChatAgent): - dataset_tools: Sequence[BaseTool] - - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - - def should_use_agent(self, query: str): - """ - return should use agent - Using the ReACT mode to determine whether an agent is needed is costly, - so it's better to just use an Agent for reasoning, which is cheaper. - - :param query: - :return: - """ - return True - - def plan( - self, - intermediate_steps: list[tuple[AgentAction, str]], - callbacks: Callbacks = None, - **kwargs: Any, - ) -> Union[AgentAction, AgentFinish]: - """Given input, decided what to do. - - Args: - intermediate_steps: Steps the LLM has taken to date, - along with observations - callbacks: Callbacks to run. - **kwargs: User inputs. - - Returns: - Action specifying what tool to use. - """ - if len(self.dataset_tools) == 0: - return AgentFinish(return_values={"output": ''}, log='') - elif len(self.dataset_tools) == 1: - tool = next(iter(self.dataset_tools)) - rst = tool.run(tool_input={'query': kwargs['input']}) - return AgentFinish(return_values={"output": rst}, log=rst) - - if intermediate_steps: - _, observation = intermediate_steps[-1] - return AgentFinish(return_values={"output": observation}, log=observation) - - full_inputs = self.get_full_inputs(intermediate_steps, **kwargs) - - try: - full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs) - except Exception as e: - raise e - - try: - agent_decision = self.output_parser.parse(full_output) - if isinstance(agent_decision, AgentAction): - tool_inputs = agent_decision.tool_input - if isinstance(tool_inputs, dict) and 'query' in tool_inputs: - tool_inputs['query'] = kwargs['input'] - agent_decision.tool_input = tool_inputs - elif isinstance(tool_inputs, str): - agent_decision.tool_input = kwargs['input'] - else: - agent_decision.return_values['output'] = '' - return agent_decision - except OutputParserException: - return AgentFinish({"output": "I'm sorry, the answer of model is invalid, " - "I don't know how to respond to that."}, "") - - @classmethod - def create_prompt( - cls, - tools: Sequence[BaseTool], - prefix: str = PREFIX, - suffix: str = SUFFIX, - human_message_template: str = HUMAN_MESSAGE_TEMPLATE, - format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[list[str]] = None, - memory_prompts: Optional[list[BasePromptTemplate]] = None, - ) -> BasePromptTemplate: - tool_strings = [] - for tool in tools: - args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args))) - tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}") - formatted_tools = "\n".join(tool_strings) - unique_tool_names = set(tool.name for tool in tools) - tool_names = ", ".join('"' + name + '"' for name in unique_tool_names) - format_instructions = format_instructions.format(tool_names=tool_names) - template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix]) - if 
input_variables is None: - input_variables = ["input", "agent_scratchpad"] - _memory_prompts = memory_prompts or [] - messages = [ - SystemMessagePromptTemplate.from_template(template), - *_memory_prompts, - HumanMessagePromptTemplate.from_template(human_message_template), - ] - return ChatPromptTemplate(input_variables=input_variables, messages=messages) - - @classmethod - def create_completion_prompt( - cls, - tools: Sequence[BaseTool], - prefix: str = PREFIX, - format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[list[str]] = None, - ) -> PromptTemplate: - """Create prompt in the style of the zero shot agent. - - Args: - tools: List of tools the agent will have access to, used to format the - prompt. - prefix: String to put before the list of tools. - input_variables: List of input variables the final prompt will expect. - - Returns: - A PromptTemplate with the template assembled from the pieces here. - """ - suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. -Question: {input} -Thought: {agent_scratchpad} -""" - - tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools]) - tool_names = ", ".join([tool.name for tool in tools]) - format_instructions = format_instructions.format(tool_names=tool_names) - template = "\n\n".join([prefix, tool_strings, format_instructions, suffix]) - if input_variables is None: - input_variables = ["input", "agent_scratchpad"] - return PromptTemplate(template=template, input_variables=input_variables) - - def _construct_scratchpad( - self, intermediate_steps: list[tuple[AgentAction, str]] - ) -> str: - agent_scratchpad = "" - for action, observation in intermediate_steps: - agent_scratchpad += action.log - agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}" - - if not isinstance(agent_scratchpad, str): - raise ValueError("agent_scratchpad should be of type string.") - if agent_scratchpad: - llm_chain = cast(LLMChain, self.llm_chain) - if llm_chain.model_config.mode == "chat": - return ( - f"This was your previous work " - f"(but I haven't seen any of it! 
I only see what " - f"you return as final answer):\n{agent_scratchpad}" - ) - else: - return agent_scratchpad - else: - return agent_scratchpad - - @classmethod - def from_llm_and_tools( - cls, - model_config: ModelConfigWithCredentialsEntity, - tools: Sequence[BaseTool], - callback_manager: Optional[BaseCallbackManager] = None, - output_parser: Optional[AgentOutputParser] = None, - prefix: str = PREFIX, - suffix: str = SUFFIX, - human_message_template: str = HUMAN_MESSAGE_TEMPLATE, - format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[list[str]] = None, - memory_prompts: Optional[list[BasePromptTemplate]] = None, - **kwargs: Any, - ) -> Agent: - """Construct an agent from an LLM and tools.""" - cls._validate_tools(tools) - if model_config.mode == "chat": - prompt = cls.create_prompt( - tools, - prefix=prefix, - suffix=suffix, - human_message_template=human_message_template, - format_instructions=format_instructions, - input_variables=input_variables, - memory_prompts=memory_prompts, - ) - else: - prompt = cls.create_completion_prompt( - tools, - prefix=prefix, - format_instructions=format_instructions, - input_variables=input_variables - ) - - llm_chain = LLMChain( - model_config=model_config, - prompt=prompt, - callback_manager=callback_manager, - parameters={ - 'temperature': 0.2, - 'top_p': 0.3, - 'max_tokens': 1500 - } - ) - tool_names = [tool.name for tool in tools] - _output_parser = output_parser - return cls( - llm_chain=llm_chain, - allowed_tools=tool_names, - output_parser=_output_parser, - dataset_tools=tools, - **kwargs, - ) diff --git a/api/core/rag/retrieval/agent_based_dataset_executor.py b/api/core/rag/retrieval/agent_based_dataset_executor.py deleted file mode 100644 index cb475bcffb791..0000000000000 --- a/api/core/rag/retrieval/agent_based_dataset_executor.py +++ /dev/null @@ -1,117 +0,0 @@ -import logging -from typing import Optional, Union - -from langchain.agents import AgentExecutor as LCAgentExecutor -from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent -from langchain.callbacks.manager import Callbacks -from langchain.tools import BaseTool -from pydantic import BaseModel, Extra - -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.entities.agent_entities import PlanningStrategy -from core.entities.message_entities import prompt_messages_to_lc_messages -from core.helper import moderation -from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.errors.invoke import InvokeError -from core.rag.retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent -from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser -from core.rag.retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent -from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool -from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool - - -class AgentConfiguration(BaseModel): - strategy: PlanningStrategy - model_config: ModelConfigWithCredentialsEntity - tools: list[BaseTool] - summary_model_config: Optional[ModelConfigWithCredentialsEntity] = None - memory: Optional[TokenBufferMemory] = None - callbacks: Callbacks = None - max_iterations: int = 6 - max_execution_time: Optional[float] = None - early_stopping_method: str = "generate" - # `generate` will continue to complete the last inference after reaching the iteration limit or 
request time limit - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - -class AgentExecuteResult(BaseModel): - strategy: PlanningStrategy - output: Optional[str] - configuration: AgentConfiguration - - -class AgentExecutor: - def __init__(self, configuration: AgentConfiguration): - self.configuration = configuration - self.agent = self._init_agent() - - def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]: - if self.configuration.strategy == PlanningStrategy.ROUTER: - self.configuration.tools = [t for t in self.configuration.tools - if isinstance(t, DatasetRetrieverTool) - or isinstance(t, DatasetMultiRetrieverTool)] - agent = MultiDatasetRouterAgent.from_llm_and_tools( - model_config=self.configuration.model_config, - tools=self.configuration.tools, - extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages()) - if self.configuration.memory else None, - verbose=True - ) - elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER: - self.configuration.tools = [t for t in self.configuration.tools - if isinstance(t, DatasetRetrieverTool) - or isinstance(t, DatasetMultiRetrieverTool)] - agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools( - model_config=self.configuration.model_config, - tools=self.configuration.tools, - output_parser=StructuredChatOutputParser(), - verbose=True - ) - else: - raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}") - - return agent - - def should_use_agent(self, query: str) -> bool: - return self.agent.should_use_agent(query) - - def run(self, query: str) -> AgentExecuteResult: - moderation_result = moderation.check_moderation( - self.configuration.model_config, - query - ) - - if moderation_result: - return AgentExecuteResult( - output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.", - strategy=self.configuration.strategy, - configuration=self.configuration - ) - - agent_executor = LCAgentExecutor.from_agent_and_tools( - agent=self.agent, - tools=self.configuration.tools, - max_iterations=self.configuration.max_iterations, - max_execution_time=self.configuration.max_execution_time, - early_stopping_method=self.configuration.early_stopping_method, - callbacks=self.configuration.callbacks - ) - - try: - output = agent_executor.run(input=query) - except InvokeError as ex: - raise ex - except Exception as ex: - logging.exception("agent_executor run failed") - output = None - - return AgentExecuteResult( - output=output, - strategy=self.configuration.strategy, - configuration=self.configuration - ) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index ee728423262fc..155b8be06c0bb 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -1,23 +1,43 @@ +import threading from typing import Optional, cast -from langchain.tools import BaseTool +from flask import Flask, current_app from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.agent_entities import PlanningStrategy from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.model_entities import 
ModelFeature +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.message_entities import PromptMessageTool +from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.rag.retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor +from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.models.document import Document +from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter +from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter +from core.rerank.rerank import RerankRunner from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool +from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool from extensions.ext_database import db -from models.dataset import Dataset +from models.dataset import Dataset, DatasetQuery, DocumentSegment +from models.dataset import Document as DatasetDocument + +default_retrieval_model = { + 'search_method': 'semantic_search', + 'reranking_enable': False, + 'reranking_model': { + 'reranking_provider_name': '', + 'reranking_model_name': '' + }, + 'top_k': 2, + 'score_threshold_enabled': False +} class DatasetRetrieval: - def retrieve(self, tenant_id: str, + def retrieve(self, app_id: str, user_id: str, tenant_id: str, model_config: ModelConfigWithCredentialsEntity, config: DatasetEntity, query: str, @@ -27,6 +47,8 @@ def retrieve(self, tenant_id: str, memory: Optional[TokenBufferMemory] = None) -> Optional[str]: """ Retrieve dataset. 
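+        Route the query to a single matching dataset or to all selected datasets in parallel, depending on the retrieve strategy, then join the retrieved segments into a context string.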
+ :param app_id: app_id + :param user_id: user_id :param tenant_id: tenant id :param model_config: model config :param config: dataset config @@ -38,12 +60,22 @@ def retrieve(self, tenant_id: str, :return: """ dataset_ids = config.dataset_ids + if len(dataset_ids) == 0: + return None retrieve_config = config.retrieve_config # check model is support tool calling model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + provider=model_config.provider, + model=model_config.model + ) + # get model schema model_schema = model_type_instance.get_model_schema( model=model_config.model, @@ -59,38 +91,291 @@ def retrieve(self, tenant_id: str, if ModelFeature.TOOL_CALL in features \ or ModelFeature.MULTI_TOOL_CALL in features: planning_strategy = PlanningStrategy.ROUTER + available_datasets = [] + for dataset_id in dataset_ids: + # get dataset from dataset id + dataset = db.session.query(Dataset).filter( + Dataset.tenant_id == tenant_id, + Dataset.id == dataset_id + ).first() + + # pass if dataset is not available + if not dataset: + continue + + # pass if dataset is not available + if (dataset and dataset.available_document_count == 0 + and dataset.available_document_count == 0): + continue + + available_datasets.append(dataset) + all_documents = [] + user_from = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user' + if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: + all_documents = self.single_retrieve(app_id, tenant_id, user_id, user_from, available_datasets, query, + model_instance, + model_config, planning_strategy) + elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: + all_documents = self.multiple_retrieve(app_id, tenant_id, user_id, user_from, + available_datasets, query, retrieve_config.top_k, + retrieve_config.score_threshold, + retrieve_config.reranking_model.get('reranking_provider_name'), + retrieve_config.reranking_model.get('reranking_model_name')) + + document_score_list = {} + for item in all_documents: + if 'score' in item.metadata and item.metadata['score']: + document_score_list[item.metadata['doc_id']] = item.metadata['score'] + + document_context_list = [] + index_node_ids = [document.metadata['doc_id'] for document in all_documents] + segments = DocumentSegment.query.filter( + DocumentSegment.dataset_id.in_(dataset_ids), + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == 'completed', + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids) + ).all() + + if segments: + index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} + sorted_segments = sorted(segments, + key=lambda segment: index_node_id_to_position.get(segment.index_node_id, + float('inf'))) + for segment in sorted_segments: + if segment.answer: + document_context_list.append(f'question:{segment.content} answer:{segment.answer}') + else: + document_context_list.append(segment.content) + if show_retrieve_source: + context_list = [] + resource_number = 1 + for segment in sorted_segments: + dataset = Dataset.query.filter_by( + id=segment.dataset_id + ).first() + document = DatasetDocument.query.filter(DatasetDocument.id == segment.document_id, + DatasetDocument.enabled == True, + 
DatasetDocument.archived == False, + ).first() + if dataset and document: + source = { + 'position': resource_number, + 'dataset_id': dataset.id, + 'dataset_name': dataset.name, + 'document_id': document.id, + 'document_name': document.name, + 'data_source_type': document.data_source_type, + 'segment_id': segment.id, + 'retriever_from': invoke_from.to_source(), + 'score': document_score_list.get(segment.index_node_id, None) + } + + if invoke_from.to_source() == 'dev': + source['hit_count'] = segment.hit_count + source['word_count'] = segment.word_count + source['segment_position'] = segment.position + source['index_node_hash'] = segment.index_node_hash + if segment.answer: + source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' + else: + source['content'] = segment.content + context_list.append(source) + resource_number += 1 + if hit_callback: + hit_callback.return_retriever_resource_info(context_list) + + return str("\n".join(document_context_list)) + return '' + + def single_retrieve(self, app_id: str, + tenant_id: str, + user_id: str, + user_from: str, + available_datasets: list, + query: str, + model_instance: ModelInstance, + model_config: ModelConfigWithCredentialsEntity, + planning_strategy: PlanningStrategy, + ): + tools = [] + for dataset in available_datasets: + description = dataset.description + if not description: + description = 'useful for when you want to answer queries about the ' + dataset.name + + description = description.replace('\n', '').replace('\r', '') + message_tool = PromptMessageTool( + name=dataset.id, + description=description, + parameters={ + "type": "object", + "properties": {}, + "required": [], + } + ) + tools.append(message_tool) + dataset_id = None + if planning_strategy == PlanningStrategy.REACT_ROUTER: + react_multi_dataset_router = ReactMultiDatasetRouter() + dataset_id = react_multi_dataset_router.invoke(query, tools, model_config, model_instance, + user_id, tenant_id) - dataset_retriever_tools = self.to_dataset_retriever_tool( + elif planning_strategy == PlanningStrategy.ROUTER: + function_call_router = FunctionCallMultiDatasetRouter() + dataset_id = function_call_router.invoke(query, tools, model_config, model_instance) + + if dataset_id: + # get retrieval model config + dataset = db.session.query(Dataset).filter( + Dataset.id == dataset_id + ).first() + if dataset: + retrieval_model_config = dataset.retrieval_model \ + if dataset.retrieval_model else default_retrieval_model + + # get top k + top_k = retrieval_model_config['top_k'] + # get retrieval method + if dataset.indexing_technique == "economy": + retrival_method = 'keyword_search' + else: + retrival_method = retrieval_model_config['search_method'] + # get reranking model + reranking_model = retrieval_model_config['reranking_model'] \ + if retrieval_model_config['reranking_enable'] else None + # get score threshold + score_threshold = .0 + score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") + if score_threshold_enabled: + score_threshold = retrieval_model_config.get("score_threshold") + + results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, + query=query, + top_k=top_k, score_threshold=score_threshold, + reranking_model=reranking_model) + self._on_query(query, [dataset_id], app_id, user_from, user_id) + if results: + self._on_retrival_end(results) + return results + return [] + + def multiple_retrieve(self, + app_id: str, + tenant_id: str, + user_id: str, + user_from: str, + available_datasets: list, + 
query: str, + top_k: int, + score_threshold: float, + reranking_provider_name: str, + reranking_model_name: str): + threads = [] + all_documents = [] + dataset_ids = [dataset.id for dataset in available_datasets] + for dataset in available_datasets: + retrieval_thread = threading.Thread(target=self._retriever, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset_id': dataset.id, + 'query': query, + 'top_k': top_k, + 'all_documents': all_documents, + }) + threads.append(retrieval_thread) + retrieval_thread.start() + for thread in threads: + thread.join() + # do rerank for searched documents + model_manager = ModelManager() + rerank_model_instance = model_manager.get_model_instance( tenant_id=tenant_id, - dataset_ids=dataset_ids, - retrieve_config=retrieve_config, - return_resource=show_retrieve_source, - invoke_from=invoke_from, - hit_callback=hit_callback + provider=reranking_provider_name, + model_type=ModelType.RERANK, + model=reranking_model_name ) - if len(dataset_retriever_tools) == 0: - return None + rerank_runner = RerankRunner(rerank_model_instance) + all_documents = rerank_runner.run(query, all_documents, + score_threshold, + top_k) + self._on_query(query, dataset_ids, app_id, user_from, user_id) + if all_documents: + self._on_retrival_end(all_documents) + return all_documents - agent_configuration = AgentConfiguration( - strategy=planning_strategy, - model_config=model_config, - tools=dataset_retriever_tools, - memory=memory, - max_iterations=10, - max_execution_time=400.0, - early_stopping_method="generate" - ) + def _on_retrival_end(self, documents: list[Document]) -> None: + """Handle retrival end.""" + for document in documents: + query = db.session.query(DocumentSegment).filter( + DocumentSegment.index_node_id == document.metadata['doc_id'] + ) - agent_executor = AgentExecutor(agent_configuration) + # if 'dataset_id' in document.metadata: + if 'dataset_id' in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id']) - should_use_agent = agent_executor.should_use_agent(query) - if not should_use_agent: - return None + # add hit count to document segment + query.update( + {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, + synchronize_session=False + ) + + db.session.commit() + + def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str) -> None: + """ + Handle query. 
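+        Record a DatasetQuery row for each routed dataset so the query is attributed to the originating app and user.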
+ """ + if not query: + return + for dataset_id in dataset_ids: + dataset_query = DatasetQuery( + dataset_id=dataset_id, + content=query, + source='app', + source_app_id=app_id, + created_by_role=user_from, + created_by=user_id + ) + db.session.add(dataset_query) + db.session.commit() + + def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): + with flask_app.app_context(): + dataset = db.session.query(Dataset).filter( + Dataset.id == dataset_id + ).first() + + if not dataset: + return [] + + # get retrieval model , if the model is not setting , using default + retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model - result = agent_executor.run(query) + if dataset.indexing_technique == "economy": + # use keyword table query + documents = RetrievalService.retrieve(retrival_method='keyword_search', + dataset_id=dataset.id, + query=query, + top_k=top_k + ) + if documents: + all_documents.extend(documents) + else: + if top_k > 0: + # retrieval source + documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=retrieval_model['score_threshold'] + if retrieval_model['score_threshold_enabled'] else None, + reranking_model=retrieval_model['reranking_model'] + if retrieval_model['reranking_enable'] else None + ) - return result.output + all_documents.extend(documents) def to_dataset_retriever_tool(self, tenant_id: str, dataset_ids: list[str], @@ -98,7 +383,7 @@ def to_dataset_retriever_tool(self, tenant_id: str, return_resource: bool, invoke_from: InvokeFrom, hit_callback: DatasetIndexToolCallbackHandler) \ - -> Optional[list[BaseTool]]: + -> Optional[list[DatasetRetrieverBaseTool]]: """ A dataset tool is a tool that can be used to retrieve information from a dataset :param tenant_id: tenant id diff --git a/api/core/rag/retrieval/output_parser/__init__.py b/api/core/rag/retrieval/output_parser/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/api/core/rag/retrieval/output_parser/react_output.py b/api/core/rag/retrieval/output_parser/react_output.py new file mode 100644 index 0000000000000..9a14d417164e6 --- /dev/null +++ b/api/core/rag/retrieval/output_parser/react_output.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import NamedTuple, Union + + +@dataclass +class ReactAction: + """A full description of an action for an ReactAction to execute.""" + + tool: str + """The name of the Tool to execute.""" + tool_input: Union[str, dict] + """The input to pass in to the Tool.""" + log: str + """Additional information to log about the action.""" + + +class ReactFinish(NamedTuple): + """The final return value of an ReactFinish.""" + + return_values: dict + """Dictionary of return values.""" + log: str + """Additional information to log about the return value""" diff --git a/api/core/rag/retrieval/output_parser/structured_chat.py b/api/core/rag/retrieval/output_parser/structured_chat.py new file mode 100644 index 0000000000000..60770bd4c6e06 --- /dev/null +++ b/api/core/rag/retrieval/output_parser/structured_chat.py @@ -0,0 +1,25 @@ +import json +import re +from typing import Union + +from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish + + +class StructuredChatOutputParser: + def parse(self, text: str) -> Union[ReactAction, ReactFinish]: + try: + action_match = re.search(r"```(\w*)\n?({.*?)```", 
text, re.DOTALL) + if action_match is not None: + response = json.loads(action_match.group(2).strip(), strict=False) + if isinstance(response, list): + response = response[0] + if response["action"] == "Final Answer": + return ReactFinish({"output": response["action_input"]}, text) + else: + return ReactAction( + response["action"], response.get("action_input", {}), text + ) + else: + return ReactFinish({"output": text}, text) + except Exception as e: + raise ValueError(f"Could not parse LLM output: {text}") diff --git a/api/core/workflow/nodes/knowledge_retrieval/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py similarity index 100% rename from api/core/workflow/nodes/knowledge_retrieval/multi_dataset_function_call_router.py rename to api/core/rag/retrieval/router/multi_dataset_function_call_router.py diff --git a/api/core/workflow/nodes/knowledge_retrieval/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py similarity index 79% rename from api/core/workflow/nodes/knowledge_retrieval/multi_dataset_react_route.py rename to api/core/rag/retrieval/router/multi_dataset_react_route.py index a2e3cd71a5a6a..5de2a66e2dacb 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -1,21 +1,21 @@ from collections.abc import Generator, Sequence -from typing import Optional, Union - -from langchain import PromptTemplate -from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE -from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX -from langchain.schema import AgentAction +from typing import Union from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage -from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser -from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate +from core.rag.retrieval.output_parser.react_output import ReactAction +from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser from core.workflow.nodes.llm.llm_node import LLMNode +PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:""" + +SUFFIX = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. +Thought:""" + FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. 
Valid "action" values: "Final Answer" or {tool_names} @@ -55,11 +55,10 @@ def invoke( self, query: str, dataset_tools: list[PromptMessageTool], - node_data: KnowledgeRetrievalNodeData, model_config: ModelConfigWithCredentialsEntity, model_instance: ModelInstance, user_id: str, - tenant_id: str, + tenant_id: str ) -> Union[str, None]: """Given input, decided what to do. @@ -72,7 +71,8 @@ def invoke( return dataset_tools[0].name try: - return self._react_invoke(query=query, node_data=node_data, model_config=model_config, model_instance=model_instance, + return self._react_invoke(query=query, model_config=model_config, + model_instance=model_instance, tools=dataset_tools, user_id=user_id, tenant_id=tenant_id) except Exception as e: return None @@ -80,7 +80,6 @@ def invoke( def _react_invoke( self, query: str, - node_data: KnowledgeRetrievalNodeData, model_config: ModelConfigWithCredentialsEntity, model_instance: ModelInstance, tools: Sequence[PromptMessageTool], @@ -88,7 +87,6 @@ def _react_invoke( tenant_id: str, prefix: str = PREFIX, suffix: str = SUFFIX, - human_message_template: str = HUMAN_MESSAGE_TEMPLATE, format_instructions: str = FORMAT_INSTRUCTIONS, ) -> Union[str, None]: if model_config.mode == "chat": @@ -97,7 +95,6 @@ def _react_invoke( tools=tools, prefix=prefix, suffix=suffix, - human_message_template=human_message_template, format_instructions=format_instructions, ) else: @@ -105,7 +102,6 @@ def _react_invoke( tools=tools, prefix=prefix, format_instructions=format_instructions, - input_variables=None ) stop = ['Observation:'] # handle invoke result @@ -121,7 +117,7 @@ def _react_invoke( model_config=model_config ) result_text, usage = self._invoke_llm( - node_data=node_data, + completion_param=model_config.parameters, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, @@ -129,18 +125,18 @@ def _react_invoke( tenant_id=tenant_id ) output_parser = StructuredChatOutputParser() - agent_decision = output_parser.parse(result_text) - if isinstance(agent_decision, AgentAction): - return agent_decision.tool + react_decision = output_parser.parse(result_text) + if isinstance(react_decision, ReactAction): + return react_decision.tool return None - def _invoke_llm(self, node_data: KnowledgeRetrievalNodeData, + def _invoke_llm(self, completion_param: dict, model_instance: ModelInstance, prompt_messages: list[PromptMessage], - stop: list[str], user_id: str, tenant_id: str) -> tuple[str, LLMUsage]: + stop: list[str], user_id: str, tenant_id: str + ) -> tuple[str, LLMUsage]: """ Invoke large language model - :param node_data: node data :param model_instance: model instance :param prompt_messages: prompt messages :param stop: stop @@ -148,7 +144,7 @@ def _invoke_llm(self, node_data: KnowledgeRetrievalNodeData, """ invoke_result = model_instance.invoke_llm( prompt_messages=prompt_messages, - model_parameters=node_data.single_retrieval_config.model.completion_params, + model_parameters=completion_param, stop=stop, stream=True, user=user_id, @@ -198,12 +194,12 @@ def create_chat_prompt( tools: Sequence[PromptMessageTool], prefix: str = PREFIX, suffix: str = SUFFIX, - human_message_template: str = HUMAN_MESSAGE_TEMPLATE, format_instructions: str = FORMAT_INSTRUCTIONS, ) -> list[ChatModelMessage]: tool_strings = [] for tool in tools: - tool_strings.append(f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}") + tool_strings.append( + f"{tool.name}: 
{tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}") formatted_tools = "\n".join(tool_strings) unique_tool_names = set(tool.name for tool in tools) tool_names = ", ".join('"' + name + '"' for name in unique_tool_names) @@ -227,16 +223,13 @@ def create_completion_prompt( tools: Sequence[PromptMessageTool], prefix: str = PREFIX, format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[list[str]] = None, - ) -> PromptTemplate: + ) -> CompletionModelPromptTemplate: """Create prompt in the style of the zero shot agent. Args: tools: List of tools the agent will have access to, used to format the prompt. prefix: String to put before the list of tools. - input_variables: List of input variables the final prompt will expect. - Returns: A PromptTemplate with the template assembled from the pieces here. """ @@ -249,6 +242,4 @@ def create_completion_prompt( tool_names = ", ".join([tool.name for tool in tools]) format_instructions = format_instructions.format(tool_names=tool_names) template = "\n\n".join([prefix, tool_strings, format_instructions, suffix]) - if input_variables is None: - input_variables = ["input", "agent_scratchpad"] - return PromptTemplate(template=template, input_variables=input_variables) + return CompletionModelPromptTemplate(text=template) diff --git a/api/core/tools/prompt/template.py b/api/core/tools/prompt/template.py index 3d35592279228..b0cf1a77fb177 100644 --- a/api/core/tools/prompt/template.py +++ b/api/core/tools/prompt/template.py @@ -38,8 +38,10 @@ ``` Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. +{{historic_messages}} Question: {{query}} -Thought: {{agent_scratchpad}}""" +{{agent_scratchpad}} +Thought:""" ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}} Thought:""" diff --git a/api/core/tools/provider/_position.yaml b/api/core/tools/provider/_position.yaml index 7eb40b2ab8941..414bd7e38cfc1 100644 --- a/api/core/tools/provider/_position.yaml +++ b/api/core/tools/provider/_position.yaml @@ -1,6 +1,7 @@ - google - bing - duckduckgo +- searxng - dalle - azuredalle - wikipedia @@ -12,6 +13,7 @@ - pubmed - stablediffusion - webscraper +- jina - model.zhipuai - aippt - youtube diff --git a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py index 033d942f4d44f..fb64e07a8c68f 100644 --- a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py +++ b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py @@ -1,11 +1,92 @@ -from typing import Any +import logging +from typing import Any, Optional -from langchain.utilities import ArxivAPIWrapper +import arxiv from pydantic import BaseModel, Field from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool +logger = logging.getLogger(__name__) +class ArxivAPIWrapper(BaseModel): + """Wrapper around ArxivAPI. + + To use, you should have the ``arxiv`` python package installed. + https://lukasschwab.me/arxiv.py/index.html + This wrapper will use the Arxiv API to conduct searches and + fetch document summaries. By default, it will return the document summaries + of the top-k results. + It limits the Document content by doc_content_chars_max. 
+ Set doc_content_chars_max=None if you don't want to limit the content size. + + Args: + top_k_results: number of the top-scored document used for the arxiv tool + ARXIV_MAX_QUERY_LENGTH: the cut limit on the query used for the arxiv tool. + load_max_docs: a limit to the number of loaded documents + load_all_available_meta: + if True: the `metadata` of the loaded Documents contains all available + meta info (see https://lukasschwab.me/arxiv.py/index.html#Result), + if False: the `metadata` contains only the published date, title, + authors and summary. + doc_content_chars_max: an optional cut limit for the length of a document's + content + + Example: + .. code-block:: python + + arxiv = ArxivAPIWrapper( + top_k_results = 3, + ARXIV_MAX_QUERY_LENGTH = 300, + load_max_docs = 3, + load_all_available_meta = False, + doc_content_chars_max = 40000 + ) + arxiv.run("tree of thought llm") + """ + + arxiv_search = arxiv.Search #: :meta private: + arxiv_exceptions = ( + arxiv.ArxivError, + arxiv.UnexpectedEmptyPageError, + arxiv.HTTPError, + ) # :meta private: + top_k_results: int = 3 + ARXIV_MAX_QUERY_LENGTH = 300 + load_max_docs: int = 100 + load_all_available_meta: bool = False + doc_content_chars_max: Optional[int] = 4000 + + def run(self, query: str) -> str: + """ + Performs an arxiv search and returns a single string + with the publish date, title, authors, and summary + for each article, separated by two newlines. + + If an error occurs or no documents are found, error text + is returned instead. Wrapper for + https://lukasschwab.me/arxiv.py/index.html#Search + + Args: + query: a plaintext search query + """ # noqa: E501 + try: + results = self.arxiv_search( # type: ignore + query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results + ).results() + except self.arxiv_exceptions as ex: + return f"Arxiv exception: {ex}" + docs = [ + f"Published: {result.updated.date()}\n" + f"Title: {result.title}\n" + f"Authors: {', '.join(a.name for a in result.authors)}\n" + f"Summary: {result.summary}" + for result in results + ] + if docs: + return "\n\n".join(docs)[: self.doc_content_chars_max] + else: + return "No good Arxiv Result was found" + class ArxivSearchInput(BaseModel): query: str = Field(..., description="Search query.") diff --git a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py index 8f11d2173ca52..c51c4d567fd75 100644 --- a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py +++ b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py @@ -12,6 +12,7 @@ class BingSearchTool(BuiltinTool): def _invoke_bing(self, user_id: str, + server_url: str, subscription_key: str, query: str, limit: int, result_type: str, market: str, lang: str, filters: list[str]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: @@ -26,7 +27,7 @@ def _invoke_bing(self, } query = quote(query) - server_url = f'{self.url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}' + server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}' response = get(server_url, headers=headers) if response.status_code != 200: @@ -136,6 +137,7 @@ def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dic self._invoke_bing( user_id='test', + server_url=server_url, subscription_key=key, query=query, limit=limit, @@ -188,6 +190,7 @@ def _invoke(self, return self._invoke_bing( user_id=user_id, + server_url=server_url, subscription_key=key, query=query,
limit=limit, diff --git a/api/core/tools/provider/builtin/brave/tools/brave_search.py b/api/core/tools/provider/builtin/brave/tools/brave_search.py index cb91d94994d92..f121cb0e34f72 100644 --- a/api/core/tools/provider/builtin/brave/tools/brave_search.py +++ b/api/core/tools/provider/builtin/brave/tools/brave_search.py @@ -1,11 +1,95 @@ -from typing import Any +import json +from typing import Any, Optional -from langchain.tools import BraveSearch +import requests +from pydantic import BaseModel, Field from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool +class BraveSearchWrapper(BaseModel): + """Wrapper around the Brave search engine.""" + + api_key: str + """The API key to use for the Brave search engine.""" + search_kwargs: dict = Field(default_factory=dict) + """Additional keyword arguments to pass to the search request.""" + base_url = "https://api.search.brave.com/res/v1/web/search" + """The base URL for the Brave search engine.""" + + def run(self, query: str) -> str: + """Query the Brave search engine and return the results as a JSON string. + + Args: + query: The query to search for. + + Returns: The results as a JSON string. + + """ + web_search_results = self._search_request(query=query) + final_results = [ + { + "title": item.get("title"), + "link": item.get("url"), + "snippet": item.get("description"), + } + for item in web_search_results + ] + return json.dumps(final_results) + + def _search_request(self, query: str) -> list[dict]: + headers = { + "X-Subscription-Token": self.api_key, + "Accept": "application/json", + } + req = requests.PreparedRequest() + params = {**self.search_kwargs, **{"q": query}} + req.prepare_url(self.base_url, params) + if req.url is None: + raise ValueError("prepared url is None, this should not happen") + + response = requests.get(req.url, headers=headers) + if not response.ok: + raise Exception(f"HTTP error {response.status_code}") + + return response.json().get("web", {}).get("results", []) + +class BraveSearch(BaseModel): + """Tool that queries the BraveSearch.""" + + name = "brave_search" + description = ( + "a search engine. " + "useful for when you need to answer questions about current events." + " input should be a search query." + ) + search_wrapper: BraveSearchWrapper + + @classmethod + def from_api_key( + cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any + ) -> "BraveSearch": + """Create a tool from an api key. + + Args: + api_key: The api key to use. + search_kwargs: Any additional kwargs to pass to the search wrapper. + **kwargs: Any additional kwargs to pass to the tool. + + Returns: + A tool. + """ + wrapper = BraveSearchWrapper(api_key=api_key, search_kwargs=search_kwargs or {}) + return cls(search_wrapper=wrapper, **kwargs) + + def _run( + self, + query: str, + ) -> str: + """Use the tool.""" + return self.search_wrapper.run(query) + class BraveSearchTool(BuiltinTool): """ Tool for performing a search using Brave search engine. 
@@ -31,7 +115,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe tool = BraveSearch.from_api_key(api_key=api_key, search_kwargs={"count": count}) - results = tool.run(query) + results = tool._run(query) if not results: return self.create_text_message(f"No results found for '{query}' in Tavily") diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/duckduckgo_search.py b/api/core/tools/provider/builtin/duckduckgo/tools/duckduckgo_search.py index 6046a189300d9..80722a4d6e882 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/duckduckgo_search.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/duckduckgo_search.py @@ -1,16 +1,147 @@ -from typing import Any +from typing import Any, Optional -from langchain.tools import DuckDuckGoSearchRun from pydantic import BaseModel, Field from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool +class DuckDuckGoSearchAPIWrapper(BaseModel): + """Wrapper for DuckDuckGo Search API. + + Free and does not require any setup. + """ + + region: Optional[str] = "wt-wt" + safesearch: str = "moderate" + time: Optional[str] = "y" + max_results: int = 5 + + def get_snippets(self, query: str) -> list[str]: + """Run query through DuckDuckGo and return concatenated results.""" + from duckduckgo_search import DDGS + + with DDGS() as ddgs: + results = ddgs.text( + query, + region=self.region, + safesearch=self.safesearch, + timelimit=self.time, + ) + if results is None: + return ["No good DuckDuckGo Search Result was found"] + snippets = [] + for i, res in enumerate(results, 1): + if res is not None: + snippets.append(res["body"]) + if len(snippets) == self.max_results: + break + return snippets + + def run(self, query: str) -> str: + snippets = self.get_snippets(query) + return " ".join(snippets) + + def results( + self, query: str, num_results: int, backend: str = "api" + ) -> list[dict[str, str]]: + """Run query through DuckDuckGo and return metadata. + + Args: + query: The query to search for. + num_results: The number of results to return. + + Returns: + A list of dictionaries with the following keys: + snippet - The description of the result. + title - The title of the result. + link - The link to the result. + """ + from duckduckgo_search import DDGS + + with DDGS() as ddgs: + results = ddgs.text( + query, + region=self.region, + safesearch=self.safesearch, + timelimit=self.time, + backend=backend, + ) + if results is None: + return [{"Result": "No good DuckDuckGo Search Result was found"}] + + def to_metadata(result: dict) -> dict[str, str]: + if backend == "news": + return { + "date": result["date"], + "title": result["title"], + "snippet": result["body"], + "source": result["source"], + "link": result["url"], + } + return { + "snippet": result["body"], + "title": result["title"], + "link": result["href"], + } + + formatted_results = [] + for i, res in enumerate(results, 1): + if res is not None: + formatted_results.append(to_metadata(res)) + if len(formatted_results) == num_results: + break + return formatted_results + + +class DuckDuckGoSearchRun(BaseModel): + """Tool that queries the DuckDuckGo search API.""" + + name = "duckduckgo_search" + description = ( + "A wrapper around DuckDuckGo Search. " + "Useful for when you need to answer questions about current events. " + "Input should be a search query." 
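+ " Output is the plain-text concatenation of the top result snippets."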
+ ) + api_wrapper: DuckDuckGoSearchAPIWrapper = Field( + default_factory=DuckDuckGoSearchAPIWrapper + ) + + def _run( + self, + query: str, + ) -> str: + """Use the tool.""" + return self.api_wrapper.run(query) + + +class DuckDuckGoSearchResults(BaseModel): + """Tool that queries the DuckDuckGo search API and gets back json.""" + + name = "DuckDuckGo Results JSON" + description = ( + "A wrapper around Duck Duck Go Search. " + "Useful for when you need to answer questions about current events. " + "Input should be a search query. Output is a JSON array of the query results" + ) + num_results: int = 4 + api_wrapper: DuckDuckGoSearchAPIWrapper = Field( + default_factory=DuckDuckGoSearchAPIWrapper + ) + backend: str = "api" + + def _run( + self, + query: str, + ) -> str: + """Use the tool.""" + res = self.api_wrapper.results(query, self.num_results, backend=self.backend) + res_strs = [", ".join([f"{k}: {v}" for k, v in d.items()]) for d in res] + return ", ".join([f"[{rs}]" for rs in res_strs]) + class DuckDuckGoInput(BaseModel): query: str = Field(..., description="Search query.") - class DuckDuckGoSearchTool(BuiltinTool): """ Tool for performing a search using DuckDuckGo search engine. @@ -34,7 +165,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe tool = DuckDuckGoSearchRun(args_schema=DuckDuckGoInput) - result = tool.run(query) + result = tool._run(query) return self.create_text_message(self.summary(user_id=user_id, content=result)) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/google/tools/google_search.py b/api/core/tools/provider/builtin/google/tools/google_search.py index 964c7ef2041f1..0b1978ad3e4f8 100644 --- a/api/core/tools/provider/builtin/google/tools/google_search.py +++ b/api/core/tools/provider/builtin/google/tools/google_search.py @@ -70,43 +70,44 @@ def _process_response(res: dict, typ: str) -> str: raise ValueError(f"Got error from SerpAPI: {res['error']}") if typ == "text": + toret = "" if "answer_box" in res.keys() and type(res["answer_box"]) == list: - res["answer_box"] = res["answer_box"][0] + res["answer_box"] = res["answer_box"][0] + "\n" if "answer_box" in res.keys() and "answer" in res["answer_box"].keys(): - toret = res["answer_box"]["answer"] - elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys(): - toret = res["answer_box"]["snippet"] - elif ( + toret += res["answer_box"]["answer"] + "\n" + if "answer_box" in res.keys() and "snippet" in res["answer_box"].keys(): + toret += res["answer_box"]["snippet"] + "\n" + if ( "answer_box" in res.keys() and "snippet_highlighted_words" in res["answer_box"].keys() ): - toret = res["answer_box"]["snippet_highlighted_words"][0] - elif ( + for item in res["answer_box"]["snippet_highlighted_words"]: + toret += item + "\n" + if ( "sports_results" in res.keys() and "game_spotlight" in res["sports_results"].keys() ): - toret = res["sports_results"]["game_spotlight"] - elif ( + toret += res["sports_results"]["game_spotlight"] + "\n" + if ( "shopping_results" in res.keys() and "title" in res["shopping_results"][0].keys() ): - toret = res["shopping_results"][:3] - elif ( + toret += res["shopping_results"][:3] + "\n" + if ( "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys() ): - toret = res["knowledge_graph"]["description"] - elif "snippet" in res["organic_results"][0].keys(): - toret = res["organic_results"][0]["snippet"] - elif "link" in res["organic_results"][0].keys(): - toret = res["organic_results"][0]["link"] - elif 
( + toret = res["knowledge_graph"]["description"] + "\n" + if "snippet" in res["organic_results"][0].keys(): + for item in res["organic_results"]: + toret += "content: " + item["snippet"] + "\n" + "link: " + item["link"] + "\n" + if ( "images_results" in res.keys() and "thumbnail" in res["images_results"][0].keys() ): thumbnails = [item["thumbnail"] for item in res["images_results"][:10]] toret = thumbnails - else: + if toret == "": toret = "No good search result found" elif typ == "link": if "knowledge_graph" in res.keys() and "title" in res["knowledge_graph"].keys() \ diff --git a/api/core/tools/provider/builtin/jina/_assets/icon.svg b/api/core/tools/provider/builtin/jina/_assets/icon.svg new file mode 100644 index 0000000000000..2e1b00fa52e43 --- /dev/null +++ b/api/core/tools/provider/builtin/jina/_assets/icon.svg @@ -0,0 +1,4 @@ + + + + diff --git a/api/core/tools/provider/builtin/jina/jina.py b/api/core/tools/provider/builtin/jina/jina.py new file mode 100644 index 0000000000000..ed1de6f6c1ccc --- /dev/null +++ b/api/core/tools/provider/builtin/jina/jina.py @@ -0,0 +1,12 @@ +from typing import Any + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class GoogleProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + try: + pass + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/jina/jina.yaml b/api/core/tools/provider/builtin/jina/jina.yaml new file mode 100644 index 0000000000000..6ae3330f40a85 --- /dev/null +++ b/api/core/tools/provider/builtin/jina/jina.yaml @@ -0,0 +1,13 @@ +identity: + author: Dify + name: jina + label: + en_US: JinaReader + zh_Hans: JinaReader + pt_BR: JinaReader + description: + en_US: Convert any URL to an LLM-friendly input. Experience improved output for your agent and RAG systems at no cost. + zh_Hans: 将任何 URL 转换为 LLM 友好的输入。无需付费即可体验为您的 Agent 和 RAG 系统提供的改进输出。 + pt_BR: Converta qualquer URL em uma entrada amigável ao LLM. Experimente uma saída aprimorada para seus sistemas de agente e RAG sem custo. 
+ icon: icon.svg +credentials_for_provider: diff --git a/api/core/tools/provider/builtin/jina/tools/jina_reader.py b/api/core/tools/provider/builtin/jina/tools/jina_reader.py new file mode 100644 index 0000000000000..322265cefe401 --- /dev/null +++ b/api/core/tools/provider/builtin/jina/tools/jina_reader.py @@ -0,0 +1,35 @@ +from typing import Any, Union + +from yarl import URL + +from core.helper import ssrf_proxy +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class JinaReaderTool(BuiltinTool): + _jina_reader_endpoint = 'https://r.jina.ai/' + + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + url = tool_parameters['url'] + + headers = { + 'Accept': 'text/event-stream' + } + + response = ssrf_proxy.get( + str(URL(self._jina_reader_endpoint + url)), + headers=headers, + timeout=(10, 60) + ) + + if tool_parameters.get('summary', False): + return self.create_text_message(self.summary(user_id, response.text)) + + return self.create_text_message(response.text) diff --git a/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml b/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml new file mode 100644 index 0000000000000..38d66292dfe35 --- /dev/null +++ b/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml @@ -0,0 +1,41 @@ +identity: + name: jina_reader + author: Dify + label: + en_US: JinaReader + zh_Hans: JinaReader + pt_BR: JinaReader +description: + human: + en_US: Convert any URL to an LLM-friendly input. Experience improved output for your agent and RAG systems at no cost. + zh_Hans: 将任何 URL 转换为 LLM 友好的输入。无需付费即可体验为您的 Agent 和 RAG 系统提供的改进输出。 + pt_BR: Converta qualquer URL em uma entrada amigável ao LLM. Experimente uma saída aprimorada para seus sistemas de agente e RAG sem custo. + llm: A tool for scraping webpages. Input should be a URL. +parameters: + - name: url + type: string + required: true + label: + en_US: URL + zh_Hans: 网页链接 + pt_BR: URL + human_description: + en_US: used for linking to webpages + zh_Hans: 用于链接到网页 + pt_BR: used for linking to webpages + llm_description: url for scraping + form: llm + - name: summary + type: boolean + required: false + default: false + label: + en_US: Enable summary + zh_Hans: 是否启用摘要 + pt_BR: Habilitar resumo + human_description: + en_US: Enable summary for the output + zh_Hans: 为输出启用摘要 + pt_BR: Habilitar resumo para a saída + llm_description: enable summary + form: form diff --git a/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py b/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py index 1bed1fa77c284..ee465d9bca1a5 100644 --- a/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py +++ b/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py @@ -1,16 +1,187 @@ +import json +import time +import urllib.error +import urllib.parse +import urllib.request from typing import Any -from langchain.tools import PubmedQueryRun from pydantic import BaseModel, Field from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool +class PubMedAPIWrapper(BaseModel): + """ + Wrapper around PubMed API. + + This wrapper will use the PubMed API to conduct searches and fetch + document summaries. By default, it will return the document summaries + of the top-k results of an input search. 
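The JinaReaderTool added above simply prefixes the target URL with the r.jina.ai endpoint and returns the extracted text. A rough standalone equivalent of that call, using plain `requests` in place of Dify's internal `ssrf_proxy` helper and a hypothetical target URL:

```python
# Standalone sketch of the request JinaReaderTool makes; `requests` stands in for
# Dify's ssrf_proxy helper and the target URL is hypothetical.
import requests

JINA_READER_ENDPOINT = "https://r.jina.ai/"
target_url = "https://example.com/article"  # hypothetical page to convert

response = requests.get(
    JINA_READER_ENDPOINT + target_url,
    headers={"Accept": "text/event-stream"},  # same header the tool sends
    timeout=(10, 60),                         # (connect, read), matching the tool
)
response.raise_for_status()
print(response.text[:500])  # LLM-friendly text extracted from the page
```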
+ + Parameters: + top_k_results: number of the top-scored document used for the PubMed tool + load_max_docs: a limit to the number of loaded documents + load_all_available_meta: + if True: the `metadata` of the loaded Documents gets all available meta info + (see https://www.ncbi.nlm.nih.gov/books/NBK25499/#chapter4.ESearch) + if False: the `metadata` gets only the most informative fields. + """ + + base_url_esearch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?" + base_url_efetch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?" + max_retry = 5 + sleep_time = 0.2 + + # Default values for the parameters + top_k_results: int = 3 + load_max_docs: int = 25 + ARXIV_MAX_QUERY_LENGTH = 300 + doc_content_chars_max: int = 2000 + load_all_available_meta: bool = False + email: str = "your_email@example.com" + + def run(self, query: str) -> str: + """ + Run PubMed search and get the article meta information. + See https://www.ncbi.nlm.nih.gov/books/NBK25499/#chapter4.ESearch + It uses only the most informative fields of article meta information. + """ + + try: + # Retrieve the top-k results for the query + docs = [ + f"Published: {result['pub_date']}\nTitle: {result['title']}\n" + f"Summary: {result['summary']}" + for result in self.load(query[: self.ARXIV_MAX_QUERY_LENGTH]) + ] + + # Join the results and limit the character count + return ( + "\n\n".join(docs)[:self.doc_content_chars_max] + if docs + else "No good PubMed Result was found" + ) + except Exception as ex: + return f"PubMed exception: {ex}" + + def load(self, query: str) -> list[dict]: + """ + Search PubMed for documents matching the query. + Return a list of dictionaries containing the document metadata. + """ + + url = ( + self.base_url_esearch + + "db=pubmed&term=" + + str({urllib.parse.quote(query)}) + + f"&retmode=json&retmax={self.top_k_results}&usehistory=y" + ) + result = urllib.request.urlopen(url) + text = result.read().decode("utf-8") + json_text = json.loads(text) + + articles = [] + webenv = json_text["esearchresult"]["webenv"] + for uid in json_text["esearchresult"]["idlist"]: + article = self.retrieve_article(uid, webenv) + articles.append(article) + + # Convert the list of articles to a JSON string + return articles + + def retrieve_article(self, uid: str, webenv: str) -> dict: + url = ( + self.base_url_efetch + + "db=pubmed&retmode=xml&id=" + + uid + + "&webenv=" + + webenv + ) + + retry = 0 + while True: + try: + result = urllib.request.urlopen(url) + break + except urllib.error.HTTPError as e: + if e.code == 429 and retry < self.max_retry: + # Too Many Requests error + # wait for an exponentially increasing amount of time + print( + f"Too Many Requests, " + f"waiting for {self.sleep_time:.2f} seconds..." 
+ ) + time.sleep(self.sleep_time) + self.sleep_time *= 2 + retry += 1 + else: + raise e + + xml_text = result.read().decode("utf-8") + + # Get title + title = "" + if "" in xml_text and "" in xml_text: + start_tag = "" + end_tag = "" + title = xml_text[ + xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag) + ] + + # Get abstract + abstract = "" + if "" in xml_text and "" in xml_text: + start_tag = "" + end_tag = "" + abstract = xml_text[ + xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag) + ] + + # Get publication date + pub_date = "" + if "" in xml_text and "" in xml_text: + start_tag = "" + end_tag = "" + pub_date = xml_text[ + xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag) + ] + + # Return article as dictionary + article = { + "uid": uid, + "title": title, + "summary": abstract, + "pub_date": pub_date, + } + return article + + +class PubmedQueryRun(BaseModel): + """Tool that searches the PubMed API.""" + + name = "PubMed" + description = ( + "A wrapper around PubMed.org " + "Useful for when you need to answer questions about Physics, Mathematics, " + "Computer Science, Quantitative Biology, Quantitative Finance, Statistics, " + "Electrical Engineering, and Economics " + "from scientific articles on PubMed.org. " + "Input should be a search query." + ) + api_wrapper: PubMedAPIWrapper = Field(default_factory=PubMedAPIWrapper) + + def _run( + self, + query: str, + ) -> str: + """Use the Arxiv tool.""" + return self.api_wrapper.run(query) + + class PubMedInput(BaseModel): query: str = Field(..., description="Search query.") - class PubMedSearchTool(BuiltinTool): """ Tool for performing a search using PubMed search engine. @@ -34,7 +205,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe tool = PubmedQueryRun(args_schema=PubMedInput) - result = tool.run(query) + result = tool._run(query) return self.create_text_message(self.summary(user_id=user_id, content=result)) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/searxng/_assets/icon.svg b/api/core/tools/provider/builtin/searxng/_assets/icon.svg new file mode 100644 index 0000000000000..b94fe3728adbf --- /dev/null +++ b/api/core/tools/provider/builtin/searxng/_assets/icon.svg @@ -0,0 +1,56 @@ + + + + + + + image/svg+xml + + + + + + + + + + + + diff --git a/api/core/tools/provider/builtin/searxng/searxng.py b/api/core/tools/provider/builtin/searxng/searxng.py new file mode 100644 index 0000000000000..8046056093cd4 --- /dev/null +++ b/api/core/tools/provider/builtin/searxng/searxng.py @@ -0,0 +1,25 @@ +from typing import Any + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.searxng.tools.searxng_search import SearXNGSearchTool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class SearXNGProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + try: + SearXNGSearchTool().fork_tool_runtime( + meta={ + "credentials": credentials, + } + ).invoke( + user_id='', + tool_parameters={ + "query": "SearXNG", + "limit": 1, + "search_type": "page", + "result_type": "link" + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/searxng/searxng.yaml b/api/core/tools/provider/builtin/searxng/searxng.yaml new file mode 100644 index 0000000000000..c8c713cf04420 --- /dev/null +++ 
b/api/core/tools/provider/builtin/searxng/searxng.yaml @@ -0,0 +1,24 @@ +identity: + author: Junytang + name: searxng + label: + en_US: SearXNG + zh_Hans: SearXNG + description: + en_US: A free internet metasearch engine. + zh_Hans: 开源互联网元搜索引擎 + icon: icon.svg +credentials_for_provider: + searxng_base_url: + type: secret-input + required: true + label: + en_US: SearXNG base URL + zh_Hans: SearXNG base URL + help: + en_US: Please input your SearXNG base URL + zh_Hans: 请输入您的 SearXNG base URL + placeholder: + en_US: Please input your SearXNG base URL + zh_Hans: 请输入您的 SearXNG base URL + url: https://docs.dify.ai/tutorials/tool-configuration/searxng diff --git a/api/core/tools/provider/builtin/searxng/tools/searxng_search.py b/api/core/tools/provider/builtin/searxng/tools/searxng_search.py new file mode 100644 index 0000000000000..cbc5ab435a69a --- /dev/null +++ b/api/core/tools/provider/builtin/searxng/tools/searxng_search.py @@ -0,0 +1,124 @@ +import json +from typing import Any + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class SearXNGSearchResults(dict): + """Wrapper for search results.""" + + def __init__(self, data: str): + super().__init__(json.loads(data)) + self.__dict__ = self + + @property + def results(self) -> Any: + return self.get("results", []) + + +class SearXNGSearchTool(BuiltinTool): + """ + Tool for performing a search using SearXNG engine. + """ + + SEARCH_TYPE = { + "page": "general", + "news": "news", + "image": "images", + # "video": "videos", + # "file": "files" + } + LINK_FILED = { + "page": "url", + "news": "url", + "image": "img_src", + # "video": "iframe_src", + # "file": "magnetlink" + } + TEXT_FILED = { + "page": "content", + "news": "content", + "image": "img_src", + # "video": "iframe_src", + # "file": "magnetlink" + } + + def _invoke_query(self, user_id: str, host: str, query: str, search_type: str, result_type: str, topK: int = 5) -> list[dict]: + """Run query and return the results.""" + + search_type = search_type.lower() + if search_type not in self.SEARCH_TYPE.keys(): + search_type= "page" + + response = requests.get(host, params={ + "q": query, + "format": "json", + "categories": self.SEARCH_TYPE[search_type] + }) + + if response.status_code != 200: + raise Exception(f'Error {response.status_code}: {response.text}') + + search_results = SearXNGSearchResults(response.text).results[:topK] + + if result_type == 'link': + results = [] + if search_type == "page" or search_type == "news": + for r in search_results: + results.append(self.create_text_message( + text=f'{r["title"]}: {r.get(self.LINK_FILED[search_type], "")}' + )) + elif search_type == "image": + for r in search_results: + results.append(self.create_image_message( + image=r.get(self.LINK_FILED[search_type], "") + )) + else: + for r in search_results: + results.append(self.create_link_message( + link=r.get(self.LINK_FILED[search_type], "") + )) + + return results + else: + text = '' + for i, r in enumerate(search_results): + text += f'{i+1}: {r["title"]} - {r.get(self.TEXT_FILED[search_type], "")}\n' + + return self.create_text_message(text=self.summary(user_id=user_id, content=text)) + + + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + """ + Invoke the SearXNG search tool. + + Args: + user_id (str): The ID of the user invoking the tool. + tool_parameters (dict[str, Any]): The parameters for the tool invocation. 
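The `retrieve_article` helper in the new PubMedAPIWrapper retries EFetch requests on HTTP 429, doubling the wait time on each attempt up to `max_retry`. A compact, generic sketch of that backoff loop (a hypothetical `fetch` callable stands in for `urllib.request.urlopen(url)`):

```python
# Generic sketch of the retry-with-exponential-backoff loop used by
# PubMedAPIWrapper.retrieve_article for HTTP 429 responses. `fetch` is a
# hypothetical zero-argument callable standing in for urllib.request.urlopen(url).
import time
import urllib.error


def fetch_with_backoff(fetch, max_retry: int = 5, sleep_time: float = 0.2):
    retry = 0
    while True:
        try:
            return fetch()
        except urllib.error.HTTPError as e:
            if e.code == 429 and retry < max_retry:
                time.sleep(sleep_time)  # wait, then retry with a doubled delay
                sleep_time *= 2
                retry += 1
            else:
                raise  # other errors, or retries exhausted, propagate to the caller
```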
+ + Returns: + ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation. + """ + + host = self.runtime.credentials.get('searxng_base_url', None) + if not host: + raise Exception('SearXNG api is required') + + query = tool_parameters.get('query', None) + if not query: + return self.create_text_message('Please input query') + + num_results = min(tool_parameters.get('num_results', 5), 20) + search_type = tool_parameters.get('search_type', 'page') or 'page' + result_type = tool_parameters.get('result_type', 'text') or 'text' + + return self._invoke_query( + user_id=user_id, + host=host, + query=query, + search_type=search_type, + result_type=result_type, + topK=num_results) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/searxng/tools/searxng_search.yaml b/api/core/tools/provider/builtin/searxng/tools/searxng_search.yaml new file mode 100644 index 0000000000000..0edf1744f4b2f --- /dev/null +++ b/api/core/tools/provider/builtin/searxng/tools/searxng_search.yaml @@ -0,0 +1,89 @@ +identity: + name: searxng_search + author: Tice + label: + en_US: SearXNG Search + zh_Hans: SearXNG 搜索 +description: + human: + en_US: Perform searches on SearXNG and get results. + zh_Hans: 在 SearXNG 上进行搜索并获取结果。 + llm: Perform searches on SearXNG and get results. +parameters: + - name: query + type: string + required: true + label: + en_US: Query string + zh_Hans: 查询语句 + human_description: + en_US: The search query. + zh_Hans: 搜索查询语句。 + llm_description: Key words for searching + form: llm + - name: search_type + type: select + required: true + label: + en_US: search type + zh_Hans: 搜索类型 + pt_BR: search type + human_description: + en_US: search type for page, news or image. + zh_Hans: 选择搜索的类型:网页,新闻,图片。 + pt_BR: search type for page, news or image. + default: Page + options: + - value: Page + label: + en_US: Page + zh_Hans: 网页 + pt_BR: Page + - value: News + label: + en_US: News + zh_Hans: 新闻 + pt_BR: News + - value: Image + label: + en_US: Image + zh_Hans: 图片 + pt_BR: Image + form: form + - name: num_results + type: number + required: true + label: + en_US: Number of query results + zh_Hans: 返回查询数量 + human_description: + en_US: The number of query results. + zh_Hans: 返回查询结果的数量。 + form: form + default: 5 + min: 1 + max: 20 + - name: result_type + type: select + required: true + label: + en_US: result type + zh_Hans: 结果类型 + pt_BR: result type + human_description: + en_US: return a list of links or texts. + zh_Hans: 返回一个连接列表还是纯文本内容。 + pt_BR: return a list of links or texts. + default: text + options: + - value: link + label: + en_US: Link + zh_Hans: 链接 + pt_BR: Link + - value: text + label: + en_US: Text + zh_Hans: 文本 + pt_BR: Text + form: form diff --git a/api/core/tools/provider/builtin/twilio/tools/send_message.py b/api/core/tools/provider/builtin/twilio/tools/send_message.py index 984ac3e90697a..24502a3565b55 100644 --- a/api/core/tools/provider/builtin/twilio/tools/send_message.py +++ b/api/core/tools/provider/builtin/twilio/tools/send_message.py @@ -1,11 +1,70 @@ -from typing import Any, Union +from typing import Any, Optional, Union -from langchain.utilities import TwilioAPIWrapper +from pydantic import BaseModel, validator from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool +class TwilioAPIWrapper(BaseModel): + """Messaging Client using Twilio. 
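Under the hood, the SearXNG tool above issues a plain JSON search request against the configured base URL and then formats the top results. A rough equivalent of that request, assuming a hypothetical self-hosted instance whose search endpoint is http://localhost:8080/search:

```python
# Rough standalone equivalent of SearXNGSearchTool._invoke_query, assuming a
# hypothetical self-hosted SearXNG instance; not the Dify tool itself.
import requests

SEARXNG_SEARCH_URL = "http://localhost:8080/search"  # hypothetical searxng_base_url

response = requests.get(SEARXNG_SEARCH_URL, params={
    "q": "SearXNG",
    "format": "json",         # the instance must allow JSON output
    "categories": "general",  # "general" corresponds to the tool's "page" search type
})
response.raise_for_status()

results = response.json().get("results", [])[:5]  # topK, default 5 in the tool
for i, r in enumerate(results):
    print(f"{i + 1}: {r.get('title', '')} - {r.get('content', '')}")
```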
+ + To use, you should have the ``twilio`` python package installed, + and the environment variables ``TWILIO_ACCOUNT_SID``, ``TWILIO_AUTH_TOKEN``, and + ``TWILIO_FROM_NUMBER``, or pass `account_sid`, `auth_token`, and `from_number` as + named parameters to the constructor. + """ + + client: Any #: :meta private: + account_sid: Optional[str] = None + """Twilio account string identifier.""" + auth_token: Optional[str] = None + """Twilio auth token.""" + from_number: Optional[str] = None + """A Twilio phone number in [E.164](https://www.twilio.com/docs/glossary/what-e164) + format, an + [alphanumeric sender ID](https://www.twilio.com/docs/sms/send-messages#use-an-alphanumeric-sender-id), + or a [Channel Endpoint address](https://www.twilio.com/docs/sms/channels#channel-addresses) + that is enabled for the type of message you want to send. Phone numbers or + [short codes](https://www.twilio.com/docs/sms/api/short-code) purchased from + Twilio also work here. You cannot, for example, spoof messages from a private + cell phone number. If you are using `messaging_service_sid`, this parameter + must be empty. + """ # noqa: E501 + + @validator("client", pre=True, always=True) + def set_validator(cls, values: dict) -> dict: + """Validate that api key and python package exists in environment.""" + try: + from twilio.rest import Client + except ImportError: + raise ImportError( + "Could not import twilio python package. " + "Please install it with `pip install twilio`." + ) + account_sid = values.get("account_sid") + auth_token = values.get("auth_token") + values["from_number"] = values.get("from_number") + values["client"] = Client(account_sid, auth_token) + + return values + + def run(self, body: str, to: str) -> str: + """Run body through Twilio and respond with message sid. + + Args: + body: The text of the message you want to send. Can be up to 1,600 + characters in length. + to: The destination phone number in + [E.164](https://www.twilio.com/docs/glossary/what-e164) format for + SMS/MMS or + [Channel user address](https://www.twilio.com/docs/sms/channels#channel-addresses) + for other 3rd-party channels. + """ # noqa: E501 + message = self.client.messages.create(to, from_=self.from_number, body=body) + return message.sid + + class SendMessageTool(BuiltinTool): """ A tool for sending messages using Twilio API. diff --git a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py index 38b495ad6f695..ef2990bfe45a5 100644 --- a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py +++ b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py @@ -1,16 +1,79 @@ -from typing import Any, Union +from typing import Any, Optional, Union -from langchain import WikipediaAPIWrapper -from langchain.tools import WikipediaQueryRun -from pydantic import BaseModel, Field +import wikipedia from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool +WIKIPEDIA_MAX_QUERY_LENGTH = 300 -class WikipediaInput(BaseModel): - query: str = Field(..., description="search query.") +class WikipediaAPIWrapper: + """Wrapper around WikipediaAPI. + To use, you should have the ``wikipedia`` python package installed. + This wrapper will use the Wikipedia API to conduct searches and + fetch page summaries. By default, it will return the page summaries + of the top-k results. + It limits the Document content by doc_content_chars_max. 
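Stripped of the pydantic plumbing, the TwilioAPIWrapper above instantiates `twilio.rest.Client` from the supplied credentials and calls `messages.create(...)`, returning the message SID. The underlying calls look roughly like this (placeholder credentials and phone numbers; requires the `twilio` package):

```python
# Sketch of the Twilio calls TwilioAPIWrapper.run() performs. The credentials and
# phone numbers below are placeholders.
from twilio.rest import Client

client = Client("ACXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", "your_auth_token")

message = client.messages.create(
    "+15551234567",           # destination number in E.164 format
    from_="+15550006789",     # Twilio sender number
    body="Hello from Dify!",  # up to 1,600 characters
)
print(message.sid)            # the SID that run() returns
```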
+ """ + + top_k_results: int = 3 + lang: str = "en" + load_all_available_meta: bool = False + doc_content_chars_max: int = 4000 + + def __init__(self, doc_content_chars_max: int = 4000): + self.doc_content_chars_max = doc_content_chars_max + + def run(self, query: str) -> str: + wikipedia.set_lang(self.lang) + wiki_client = wikipedia + + """Run Wikipedia search and get page summaries.""" + page_titles = wiki_client.search(query[:WIKIPEDIA_MAX_QUERY_LENGTH]) + summaries = [] + for page_title in page_titles[: self.top_k_results]: + if wiki_page := self._fetch_page(page_title): + if summary := self._formatted_page_summary(page_title, wiki_page): + summaries.append(summary) + if not summaries: + return "No good Wikipedia Search Result was found" + return "\n\n".join(summaries)[: self.doc_content_chars_max] + + @staticmethod + def _formatted_page_summary(page_title: str, wiki_page: Any) -> Optional[str]: + return f"Page: {page_title}\nSummary: {wiki_page.summary}" + + def _fetch_page(self, page: str) -> Optional[str]: + try: + return wikipedia.page(title=page, auto_suggest=False) + except ( + wikipedia.exceptions.PageError, + wikipedia.exceptions.DisambiguationError, + ): + return None + +class WikipediaQueryRun: + """Tool that searches the Wikipedia API.""" + + name = "Wikipedia" + description = ( + "A wrapper around Wikipedia. " + "Useful for when you need to answer general questions about " + "people, places, companies, facts, historical events, or other subjects. " + "Input should be a search query." + ) + api_wrapper: WikipediaAPIWrapper + + def __init__(self, api_wrapper: WikipediaAPIWrapper): + self.api_wrapper = api_wrapper + + def _run( + self, + query: str, + ) -> str: + """Use the Wikipedia tool.""" + return self.api_wrapper.run(query) class WikiPediaSearchTool(BuiltinTool): def _invoke(self, user_id: str, @@ -24,14 +87,10 @@ def _invoke(self, return self.create_text_message('Please input query') tool = WikipediaQueryRun( - name="wikipedia", api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000), - args_schema=WikipediaInput ) - result = tool.run(tool_input={ - 'query': query - }) + result = tool._run(query) return self.create_text_message(self.summary(user_id=user_id,content=result)) \ No newline at end of file diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index de3cd552d4bb1..f7b963a92e886 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -1,5 +1,6 @@ import json from json import dumps +from os import getenv from typing import Any, Union from urllib.parse import urlencode @@ -13,7 +14,10 @@ from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError from core.tools.tool.tool import Tool -API_TOOL_DEFAULT_TIMEOUT = (10, 60) +API_TOOL_DEFAULT_TIMEOUT = ( + int(getenv('API_TOOL_DEFAULT_CONNECT_TIMEOUT', '10')), + int(getenv('API_TOOL_DEFAULT_READ_TIMEOUT', '60')) +) class ApiTool(Tool): api_bundle: ApiBasedToolBundle @@ -287,6 +291,16 @@ def _convert_body_property_type(self, property: dict[str, Any], value: Any) -> A elif property['type'] == 'null': if value is None: return None + elif property['type'] == 'object': + if isinstance(value, str): + try: + return json.loads(value) + except ValueError: + return value + elif isinstance(value, dict): + return value + else: + return value else: raise ValueError(f"Invalid type {property['type']} for property {property}") elif 'anyOf' in property and isinstance(property['anyOf'], list): diff --git 
a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index d9934acff9c61..6e11427d58ac2 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -1,8 +1,6 @@ import threading -from typing import Optional from flask import Flask, current_app -from langchain.tools import BaseTool from pydantic import BaseModel, Field from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler @@ -10,6 +8,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.retrieval_service import RetrievalService from core.rerank.rerank import RerankRunner +from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment @@ -29,25 +28,20 @@ class DatasetMultiRetrieverToolInput(BaseModel): query: str = Field(..., description="dataset multi retriever and rerank") -class DatasetMultiRetrieverTool(BaseTool): +class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): """Tool for querying multi dataset.""" - name: str = "dataset-" + name: str = "dataset_" args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput description: str = "dataset multi retriever and rerank. " - tenant_id: str dataset_ids: list[str] - top_k: int = 2 - score_threshold: Optional[float] = None reranking_provider_name: str reranking_model_name: str - return_resource: bool - retriever_from: str - hit_callbacks: list[DatasetIndexToolCallbackHandler] = [] + @classmethod def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs): return cls( - name=f'dataset-{tenant_id}', + name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs @@ -149,9 +143,6 @@ def _run(self, query: str) -> str: return str("\n".join(document_context_list)) - async def _arun(self, tool_input: str) -> str: - raise NotImplementedError() - def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: list, hit_callbacks: list[DatasetIndexToolCallbackHandler]): with flask_app.app_context(): diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py new file mode 100644 index 0000000000000..1f8478f5541ac --- /dev/null +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py @@ -0,0 +1,34 @@ +from abc import abstractmethod +from typing import Any, Optional + +from msal_extensions.persistence import ABC +from pydantic import BaseModel + +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler + + +class DatasetRetrieverBaseTool(BaseModel, ABC): + """Tool for querying a Dataset.""" + name: str = "dataset" + description: str = "use this to retrieve a dataset. " + tenant_id: str + top_k: int = 2 + score_threshold: Optional[float] = None + hit_callbacks: list[DatasetIndexToolCallbackHandler] = [] + return_resource: bool + retriever_from: str + + class Config: + arbitrary_types_allowed = True + + @abstractmethod + def _run( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + """Use the tool. 
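As shown earlier, api_tool.py now derives its default connect/read timeouts from environment variables instead of hard-coding (10, 60). A small sketch of the same pattern (the env var names match the diff; plain `requests` is used here only to illustrate how such a (connect, read) tuple is consumed):

```python
# Sketch of the env-configurable timeout pattern api_tool.py adopts. Env var names
# match the diff; `requests` only illustrates consuming the (connect, read) tuple.
from os import getenv

import requests

API_TOOL_DEFAULT_TIMEOUT = (
    int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")),  # connect timeout, seconds
    int(getenv("API_TOOL_DEFAULT_READ_TIMEOUT", "60")),     # read timeout, seconds
)

# Any HTTP call made by the tool can pass the tuple straight through.
response = requests.get("https://example.com", timeout=API_TOOL_DEFAULT_TIMEOUT)
print(response.status_code)
```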
+ + Add run_manager: Optional[CallbackManagerForToolRun] = None + to child implementations to enable tracing, + """ diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index 13331d981bbec..552174e0bad82 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -1,10 +1,8 @@ -from typing import Optional -from langchain.tools import BaseTool from pydantic import BaseModel, Field -from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.rag.datasource.retrieval_service import RetrievalService +from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment @@ -24,19 +22,13 @@ class DatasetRetrieverToolInput(BaseModel): query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.") -class DatasetRetrieverTool(BaseTool): +class DatasetRetrieverTool(DatasetRetrieverBaseTool): """Tool for querying a Dataset.""" name: str = "dataset" args_schema: type[BaseModel] = DatasetRetrieverToolInput description: str = "use this to retrieve a dataset. " - - tenant_id: str dataset_id: str - top_k: int = 2 - score_threshold: Optional[float] = None - hit_callbacks: list[DatasetIndexToolCallbackHandler] = [] - return_resource: bool - retriever_from: str + @classmethod def from_dataset(cls, dataset: Dataset, **kwargs): @@ -46,7 +38,7 @@ def from_dataset(cls, dataset: Dataset, **kwargs): description = description.replace('\n', '').replace('\r', '') return cls( - name=f'dataset-{dataset.id}', + name=f"dataset_{dataset.id.replace('-', '_')}", tenant_id=dataset.tenant_id, dataset_id=dataset.id, description=description, @@ -153,7 +145,4 @@ def _run(self, query: str) -> str: for hit_callback in self.hit_callbacks: hit_callback.return_retriever_resource_info(context_list) - return str("\n".join(document_context_list)) - - async def _arun(self, tool_input: str) -> str: - raise NotImplementedError() + return str("\n".join(document_context_list)) \ No newline at end of file diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py index 421f8a0483379..e52981b2d1459 100644 --- a/api/core/tools/tool/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever_tool.py @@ -1,7 +1,5 @@ from typing import Any -from langchain.tools import BaseTool - from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler @@ -14,11 +12,12 @@ ToolParameter, ToolProviderType, ) +from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from core.tools.tool.tool import Tool class DatasetRetrieverTool(Tool): - langchain_tool: BaseTool + retrival_tool: DatasetRetrieverBaseTool @staticmethod def get_dataset_tools(tenant_id: str, @@ -43,7 +42,7 @@ def get_dataset_tools(tenant_id: str, # Agent only support SINGLE mode original_retriever_mode = retrieve_config.retrieve_strategy retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE - langchain_tools = feature.to_dataset_retriever_tool( + retrival_tools = feature.to_dataset_retriever_tool( tenant_id=tenant_id, 
dataset_ids=dataset_ids, retrieve_config=retrieve_config, @@ -54,17 +53,17 @@ def get_dataset_tools(tenant_id: str, # restore retrieve strategy retrieve_config.retrieve_strategy = original_retriever_mode - # convert langchain tools to Tools + # convert retrival tools to Tools tools = [] - for langchain_tool in langchain_tools: + for retrival_tool in retrival_tools: tool = DatasetRetrieverTool( - langchain_tool=langchain_tool, - identity=ToolIdentity(provider='', author='', name=langchain_tool.name, label=I18nObject(en_US='', zh_Hans='')), + retrival_tool=retrival_tool, + identity=ToolIdentity(provider='', author='', name=retrival_tool.name, label=I18nObject(en_US='', zh_Hans='')), parameters=[], is_team_authorization=True, description=ToolDescription( human=I18nObject(en_US='', zh_Hans=''), - llm=langchain_tool.description), + llm=retrival_tool.description), runtime=DatasetRetrieverTool.Runtime() ) @@ -96,7 +95,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe return self.create_text_message(text='please input query') # invoke dataset retriever tool - result = self.langchain_tool._run(query=query) + result = self.retrival_tool._run(query=query) return self.create_text_message(text=result) diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index f26f9494a1e97..4203180992b4d 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -243,8 +243,21 @@ def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> di tool_parameters[parameter.name] = float(tool_parameters[parameter.name]) elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN: if not isinstance(tool_parameters[parameter.name], bool): - tool_parameters[parameter.name] = bool(tool_parameters[parameter.name]) - + # check if it is a string + if isinstance(tool_parameters[parameter.name], str): + # check true false + if tool_parameters[parameter.name].lower() in ['true', 'false']: + tool_parameters[parameter.name] = tool_parameters[parameter.name].lower() == 'true' + # check 1 0 + elif tool_parameters[parameter.name] in ['1', '0']: + tool_parameters[parameter.name] = tool_parameters[parameter.name] == '1' + else: + tool_parameters[parameter.name] = bool(tool_parameters[parameter.name]) + elif isinstance(tool_parameters[parameter.name], int | float): + tool_parameters[parameter.name] = tool_parameters[parameter.name] != 0 + else: + tool_parameters[parameter.name] = bool(tool_parameters[parameter.name]) + return tool_parameters @abstractmethod diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 65e7765adc37a..f96d7940bda69 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -38,7 +38,7 @@ def agent_invoke(tool: Tool, tool_parameters: Union[str, dict], if isinstance(tool_parameters, str): # check if this tool has only one parameter parameters = [ - parameter for parameter in tool.parameters + parameter for parameter in tool.get_runtime_parameters() if parameter.form == ToolParameter.ToolParameterForm.LLM ] if parameters and len(parameters) == 1: diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index dff8505c1bd77..690bdddaf6a90 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -22,9 +22,6 @@ class ValueType(Enum): class VariablePool: - variables_mapping = {} - user_inputs: dict - system_variables: dict[SystemVariable, Any] def __init__(self, system_variables: 
dict[SystemVariable, Any], user_inputs: dict) -> None: @@ -34,6 +31,7 @@ def __init__(self, system_variables: dict[SystemVariable, Any], # 'query': 'abc', # 'files': [] # } + self.variables_mapping = {} self.user_inputs = user_inputs self.system_variables = system_variables for system_variable, value in system_variables.items(): diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 097dbb660c8cd..bc1b8d7ce1e7b 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -234,6 +234,9 @@ def _transform_result(self, result: dict, output_schema: Optional[dict[str, Code parameters_validated = {} for output_name, output_config in output_schema.items(): dot = '.' if prefix else '' + if output_name not in result: + raise ValueError(f'Output {prefix}{dot}{output_name} is missing.') + if output_config.type == 'object': # check if output is object if not isinstance(result.get(output_name), dict): diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index e9aa65540117d..be3cec91525ab 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -1,28 +1,21 @@ -import threading from typing import Any, cast -from flask import Flask, current_app - from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.message_entities import PromptMessageTool from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.rag.datasource.retrieval_service import RetrievalService -from core.rerank.rerank import RerankRunner +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData -from core.workflow.nodes.knowledge_retrieval.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter -from core.workflow.nodes.knowledge_retrieval.multi_dataset_react_route import ReactMultiDatasetRouter from extensions.ext_database import db -from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment +from models.dataset import Dataset, Document, DocumentSegment from models.workflow import WorkflowNodeExecutionStatus default_retrieval_model = { @@ -106,10 +99,45 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: available_datasets.append(dataset) all_documents = [] + dataset_retrieval = DatasetRetrieval() if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value: - all_documents = self._single_retrieve(available_datasets, node_data, query) + # fetch model config + model_instance, 
model_config = self._fetch_model_config(node_data) + # check model is support tool calling + model_type_instance = model_config.provider_model_bundle.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + # get model schema + model_schema = model_type_instance.get_model_schema( + model=model_config.model, + credentials=model_config.credentials + ) + + if model_schema: + planning_strategy = PlanningStrategy.REACT_ROUTER + features = model_schema.features + if features: + if ModelFeature.TOOL_CALL in features \ + or ModelFeature.MULTI_TOOL_CALL in features: + planning_strategy = PlanningStrategy.ROUTER + all_documents = dataset_retrieval.single_retrieve( + available_datasets=available_datasets, + tenant_id=self.tenant_id, + user_id=self.user_id, + app_id=self.app_id, + user_from=self.user_from.value, + query=query, + model_config=model_config, + model_instance=model_instance, + planning_strategy=planning_strategy + ) elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: - all_documents = self._multiple_retrieve(available_datasets, node_data, query) + all_documents = dataset_retrieval.multiple_retrieve(self.app_id, self.tenant_id, self.user_id, + self.user_from.value, + available_datasets, query, + node_data.multiple_retrieval_config.top_k, + node_data.multiple_retrieval_config.score_threshold, + node_data.multiple_retrieval_config.reranking_model.provider, + node_data.multiple_retrieval_config.reranking_model.model) context_list = [] if all_documents: @@ -184,84 +212,6 @@ def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) variable_mapping['query'] = node_data.query_variable_selector return variable_mapping - def _single_retrieve(self, available_datasets, node_data, query): - tools = [] - for dataset in available_datasets: - description = dataset.description - if not description: - description = 'useful for when you want to answer queries about the ' + dataset.name - - description = description.replace('\n', '').replace('\r', '') - message_tool = PromptMessageTool( - name=dataset.id, - description=description, - parameters={ - "type": "object", - "properties": {}, - "required": [], - } - ) - tools.append(message_tool) - # fetch model config - model_instance, model_config = self._fetch_model_config(node_data) - # check model is support tool calling - model_type_instance = model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - # get model schema - model_schema = model_type_instance.get_model_schema( - model=model_config.model, - credentials=model_config.credentials - ) - - if not model_schema: - return None - planning_strategy = PlanningStrategy.REACT_ROUTER - features = model_schema.features - if features: - if ModelFeature.TOOL_CALL in features \ - or ModelFeature.MULTI_TOOL_CALL in features: - planning_strategy = PlanningStrategy.ROUTER - dataset_id = None - if planning_strategy == PlanningStrategy.REACT_ROUTER: - react_multi_dataset_router = ReactMultiDatasetRouter() - dataset_id = react_multi_dataset_router.invoke(query, tools, node_data, model_config, model_instance, - self.user_id, self.tenant_id) - - elif planning_strategy == PlanningStrategy.ROUTER: - function_call_router = FunctionCallMultiDatasetRouter() - dataset_id = function_call_router.invoke(query, tools, model_config, model_instance) - if dataset_id: - # get retrieval model config - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id 
- ).first() - if dataset: - retrieval_model_config = dataset.retrieval_model \ - if dataset.retrieval_model else default_retrieval_model - - # get top k - top_k = retrieval_model_config['top_k'] - # get retrieval method - retrival_method = retrieval_model_config['search_method'] - # get reranking model - reranking_model=retrieval_model_config['reranking_model'] \ - if retrieval_model_config['reranking_enable'] else None - # get score threshold - score_threshold = .0 - score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") - if score_threshold_enabled: - score_threshold = retrieval_model_config.get("score_threshold") - - results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, - query=query, - top_k=top_k, score_threshold=score_threshold, - reranking_model=reranking_model) - self._on_query(query, [dataset_id]) - if results: - self._on_retrival_end(results) - return results - return [] - def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ ModelInstance, ModelConfigWithCredentialsEntity]: """ @@ -332,112 +282,3 @@ def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ parameters=completion_params, stop=stop, ) - - def _multiple_retrieve(self, available_datasets, node_data, query): - threads = [] - all_documents = [] - dataset_ids = [dataset.id for dataset in available_datasets] - for dataset in available_datasets: - retrieval_thread = threading.Thread(target=self._retriever, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset.id, - 'query': query, - 'top_k': node_data.multiple_retrieval_config.top_k, - 'all_documents': all_documents, - }) - threads.append(retrieval_thread) - retrieval_thread.start() - for thread in threads: - thread.join() - # do rerank for searched documents - model_manager = ModelManager() - rerank_model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, - provider=node_data.multiple_retrieval_config.reranking_model.provider, - model_type=ModelType.RERANK, - model=node_data.multiple_retrieval_config.reranking_model.model - ) - - rerank_runner = RerankRunner(rerank_model_instance) - all_documents = rerank_runner.run(query, all_documents, - node_data.multiple_retrieval_config.score_threshold, - node_data.multiple_retrieval_config.top_k) - self._on_query(query, dataset_ids) - if all_documents: - self._on_retrival_end(all_documents) - return all_documents - - def _on_retrival_end(self, documents: list[Document]) -> None: - """Handle retrival end.""" - for document in documents: - query = db.session.query(DocumentSegment).filter( - DocumentSegment.index_node_id == document.metadata['doc_id'] - ) - - # if 'dataset_id' in document.metadata: - if 'dataset_id' in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id']) - - # add hit count to document segment - query.update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, - synchronize_session=False - ) - - db.session.commit() - - def _on_query(self, query: str, dataset_ids: list[str]) -> None: - """ - Handle query. 
- """ - if not query: - return - for dataset_id in dataset_ids: - dataset_query = DatasetQuery( - dataset_id=dataset_id, - content=query, - source='app', - source_app_id=self.app_id, - created_by_role=self.user_from.value, - created_by=self.user_id - ) - db.session.add(dataset_query) - db.session.commit() - - def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): - with flask_app.app_context(): - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == self.tenant_id, - Dataset.id == dataset_id - ).first() - - if not dataset: - return [] - - # get retrieval model , if the model is not setting , using default - retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model - - if dataset.indexing_technique == "economy": - # use keyword table query - documents = RetrievalService.retrieve(retrival_method='keyword_search', - dataset_id=dataset.id, - query=query, - top_k=top_k - ) - if documents: - all_documents.extend(documents) - else: - if top_k > 0: - # retrieval source - documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], - dataset_id=dataset.id, - query=query, - top_k=top_k, - score_threshold=retrieval_model['score_threshold'] - if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model['reranking_model'] - if retrieval_model['reranking_enable'] else None - ) - - all_documents.extend(documents) - diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index cf2f9b7176e7a..491e984477255 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -10,7 +10,7 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.entities.message_entities import PromptMessage +from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageContentType from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder @@ -248,16 +248,19 @@ def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> context_str = '' original_retriever_resource = [] for item in context_value: - if 'content' not in item: - raise ValueError(f'Invalid context structure: {item}') + if isinstance(item, str): + context_str += item + '\n' + else: + if 'content' not in item: + raise ValueError(f'Invalid context structure: {item}') - context_str += item['content'] + '\n' + context_str += item['content'] + '\n' - retriever_resource = self._convert_to_original_retriever_resource(item) - if retriever_resource: - original_retriever_resource.append(retriever_resource) + retriever_resource = self._convert_to_original_retriever_resource(item) + if retriever_resource: + original_retriever_resource.append(retriever_resource) - if self.callbacks: + if self.callbacks and original_retriever_resource: for callback in self.callbacks: callback.on_event( event=QueueRetrieverResourcesEvent( @@ -434,6 +437,22 @@ def _fetch_prompt_messages(self, node_data: LLMNodeData, ) stop = model_config.stop + vision_enabled = node_data.vision.enabled + for prompt_message in prompt_messages: + if not isinstance(prompt_message.content, str): + prompt_message_content = [] + for 
content_item in prompt_message.content: + if vision_enabled and content_item.type == PromptMessageContentType.IMAGE: + prompt_message_content.append(content_item) + elif content_item.type == PromptMessageContentType.TEXT: + prompt_message_content.append(content_item) + + if len(prompt_message_content) > 1: + prompt_message.content = prompt_message_content + elif (len(prompt_message_content) == 1 + and prompt_message_content[0].type == PromptMessageContentType.TEXT): + prompt_message.content = prompt_message_content[0].data + return prompt_messages, stop @classmethod diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 751aac011a8a8..6449e2c11c7a9 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -13,7 +13,7 @@ from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.llm.llm_node import LLMNode from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData @@ -65,7 +65,9 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: categories = [_class.name for _class in node_data.classes] try: result_text_json = json.loads(result_text.strip('```JSON\n')) - categories = result_text_json.get('categories', []) + categories_result = result_text_json.get('categories', []) + if categories_result: + categories = categories_result except Exception: logging.error(f"Failed to parse result text: {result_text}") try: @@ -89,14 +91,24 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: inputs=variables, process_data=process_data, outputs=outputs, - edge_source_handle=classes_map.get(categories[0], None) + edge_source_handle=classes_map.get(categories[0], None), + metadata={ + NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, + NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, + NodeRunMetadataKey.CURRENCY: usage.currency + } ) except ValueError as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, - error=str(e) + error=str(e), + metadata={ + NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, + NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, + NodeRunMetadataKey.CURRENCY: usage.currency + } ) @classmethod @@ -199,7 +211,7 @@ def _get_prompt_template(self, node_data: QuestionClassifierNodeData, query: str model_mode = ModelMode.value_of(node_data.model.mode) classes = node_data.classes class_names = [class_.name for class_ in classes] - class_names_str = ','.join(class_names) + class_names_str = ','.join(f'"{name}"' for name in class_names) instruction = node_data.instruction if node_data.instruction else '' input_text = query memory_str = '' diff --git a/api/core/workflow/nodes/question_classifier/template_prompts.py b/api/core/workflow/nodes/question_classifier/template_prompts.py index 829f0257bcbd5..318ad54f92f9e 100644 --- a/api/core/workflow/nodes/question_classifier/template_prompts.py +++ b/api/core/workflow/nodes/question_classifier/template_prompts.py @@ -18,7 +18,7 @@ 
QUESTION_CLASSIFIER_USER_PROMPT_1 = """ { "input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."], - "categories": ["Customer Service, Satisfaction, Sales, Product"], + "categories": ["Customer Service", "Satisfaction", "Sales", "Product"], "classification_instructions": ["classify the text based on the feedback provided by customer"]}```JSON """ @@ -29,13 +29,13 @@ QUESTION_CLASSIFIER_USER_PROMPT_2 = """ {"input_text": ["bad service, slow to bring the food"], - "categories": ["Food Quality, Experience, Price" ], + "categories": ["Food Quality", "Experience", "Price" ], "classification_instructions": []}```JSON """ QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """ {"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"], - "categories": ["Experience""]}``` + "categories": ["Experience"]}``` """ QUESTION_CLASSIFIER_USER_PROMPT_3 = """ diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index aec8f34bb9d67..97fbe8a99922c 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -1,10 +1,9 @@ -from typing import Literal, Union +from typing import Any, Literal, Union from pydantic import BaseModel, validator from core.workflow.entities.base_node_data_entities import BaseNodeData -ToolParameterValue = Union[str, int, float, bool] class ToolEntity(BaseModel): provider_id: str @@ -12,11 +11,23 @@ class ToolEntity(BaseModel): provider_name: str # redundancy tool_name: str tool_label: str # redundancy - tool_configurations: dict[str, ToolParameterValue] + tool_configurations: dict[str, Any] + + @validator('tool_configurations', pre=True, always=True) + def validate_tool_configurations(cls, value, values): + if not isinstance(value, dict): + raise ValueError('tool_configurations must be a dictionary') + + for key in values.get('tool_configurations', {}).keys(): + value = values.get('tool_configurations', {}).get(key) + if not isinstance(value, str | int | float | bool): + raise ValueError(f'{key} must be a string') + + return value class ToolNodeData(BaseNodeData, ToolEntity): class ToolInput(BaseModel): - value: Union[ToolParameterValue, list[str]] + value: Union[Any, list[str]] type: Literal['mixed', 'variable', 'constant'] @validator('type', pre=True, always=True) @@ -25,12 +36,16 @@ def check_type(cls, value, values): value = values.get('value') if typ == 'mixed' and not isinstance(value, str): raise ValueError('value must be a string') - elif typ == 'variable' and not isinstance(value, list): - raise ValueError('value must be a list') - elif typ == 'constant' and not isinstance(value, ToolParameterValue): + elif typ == 'variable': + if not isinstance(value, list): + raise ValueError('value must be a list') + for val in value: + if not isinstance(val, str): + raise ValueError('value must be a list of strings') + elif typ == 'constant' and not isinstance(value, str | int | float | bool): raise ValueError('value must be a string, int, float, or bool') return typ - + """ Tool Node Schema """ diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 532d474258539..9390ffa2a4b67 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,9 +1,10 @@ import logging import time -from typing import Optional +from typing import Optional, cast +from core.app.app_config.entities import FileExtraConfig from 
core.app.apps.base_app_queue_manager import GenerateTaskStoppedException -from core.file.file_obj import FileVar +from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool, VariableValue @@ -16,6 +17,7 @@ from core.workflow.nodes.http_request.http_request_node import HttpRequestNode from core.workflow.nodes.if_else.if_else_node import IfElseNode from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode +from core.workflow.nodes.llm.entities import LLMNodeData from core.workflow.nodes.llm.llm_node import LLMNode from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode from core.workflow.nodes.start.start_node import StartNode @@ -219,7 +221,8 @@ def single_step_run_workflow_node(self, workflow: Workflow, raise ValueError('node id not found in workflow graph') # Get node class - node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type'))) + node_type = NodeType.value_of(node_config.get('data', {}).get('type')) + node_cls = node_classes.get(node_type) # init workflow run state node_instance = node_cls( @@ -252,11 +255,40 @@ def single_step_run_workflow_node(self, workflow: Workflow, variable_node_id = variable_selector[0] variable_key_list = variable_selector[1:] + # get value + value = user_inputs.get(variable_key) + + # temp fix for image type + if node_type == NodeType.LLM: + new_value = [] + if isinstance(value, list): + node_data = node_instance.node_data + node_data = cast(LLMNodeData, node_data) + + detail = node_data.vision.configs.detail if node_data.vision.configs else None + + for item in value: + if isinstance(item, dict) and 'type' in item and item['type'] == 'image': + transfer_method = FileTransferMethod.value_of(item.get('transfer_method')) + file = FileVar( + tenant_id=workflow.tenant_id, + type=FileType.IMAGE, + transfer_method=transfer_method, + url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, + related_id=item.get( + 'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, + extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None), + ) + new_value.append(file) + + if new_value: + value = new_value + # append variable and value to variable pool variable_pool.append_variable( node_id=variable_node_id, variable_key_list=variable_key_list, - value=user_inputs.get(variable_key) + value=value ) # run node node_run_result = node_instance.run( diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py index e0f3b8499058b..9a7c0deb20ac6 100644 --- a/api/events/event_handlers/__init__.py +++ b/api/events/event_handlers/__init__.py @@ -5,7 +5,6 @@ from .create_site_record_when_app_created import handle from .deduct_quota_when_messaeg_created import handle from .delete_installed_app_when_app_deleted import handle -from .generate_conversation_name_when_first_message_created import handle from .update_app_dataset_join_when_app_model_config_updated import handle from .update_provider_last_used_at_when_messaeg_created import handle from .update_app_dataset_join_when_app_published_workflow_updated import handle diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py 
index 0b281c9271374..68dae5a5537cd 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -29,7 +29,7 @@ def handle(sender, **kwargs): raise NotFound('Document not found') document.indexing_status = 'parsing' - document.processing_started_at = datetime.datetime.utcnow() + document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) documents.append(document) db.session.add(document) db.session.commit() diff --git a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py deleted file mode 100644 index 31535bf4ef68f..0000000000000 --- a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py +++ /dev/null @@ -1,32 +0,0 @@ -from core.llm_generator.llm_generator import LLMGenerator -from events.message_event import message_was_created -from extensions.ext_database import db -from models.model import AppMode - - -@message_was_created.connect -def handle(sender, **kwargs): - message = sender - conversation = kwargs.get('conversation') - is_first_message = kwargs.get('is_first_message') - extras = kwargs.get('extras', {}) - - auto_generate_conversation_name = True - if extras: - auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True) - - if auto_generate_conversation_name and is_first_message: - if conversation.mode != AppMode.COMPLETION.value: - app_model = conversation.app - if not app_model: - return - - # generate conversation name - try: - name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query) - conversation.name = name - except: - pass - - db.session.merge(conversation) - db.session.commit() diff --git a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py index ae983cc5d1a53..81cb86118082b 100644 --- a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py +++ b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timezone from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity from events.message_event import message_was_created @@ -17,5 +17,5 @@ def handle(sender, **kwargs): db.session.query(Provider).filter( Provider.tenant_id == application_generate_entity.app_config.tenant_id, Provider.provider_name == application_generate_entity.model_config.provider - ).update({'last_used': datetime.utcnow()}) + ).update({'last_used': datetime.now(timezone.utc).replace(tzinfo=None)}) db.session.commit() diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index fcb99a9e83b50..bd4755e76889e 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -28,6 +28,7 @@ def __call__(self, *args: object, **kwargs: object) -> object: celery_app.conf.update( result_backend=app.config["CELERY_RESULT_BACKEND"], + broker_connection_retry_on_startup=True, ) if app.config["BROKER_USE_SSL"]: diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index 943cf4f58d467..c2104cdf555f1 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -2,7 +2,7 @@ import shutil from collections.abc import Generator from contextlib import closing -from datetime import 
datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Union import boto3 @@ -38,7 +38,7 @@ def init_app(self, app: Flask): account_key=app.config.get('AZURE_BLOB_ACCOUNT_KEY'), resource_types=ResourceTypes(service=True, container=True, object=True), permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True), - expiry=datetime.utcnow() + timedelta(hours=1) + expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1) ) self.client = BlobServiceClient(account_url=app.config.get('AZURE_BLOB_ACCOUNT_URL'), credential=sas_token) diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py index 55695195248c7..2cf023a39922b 100644 --- a/api/libs/json_in_md_parser.py +++ b/api/libs/json_in_md_parser.py @@ -1,6 +1,6 @@ import json -from langchain.schema import OutputParserException +from core.llm_generator.output_parser.errors import OutputParserException def parse_json_markdown(json_string: str) -> dict: diff --git a/api/models/model.py b/api/models/model.py index d34c577b5d5bd..df858e521991b 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -815,7 +815,7 @@ def files(self): @property def workflow_run(self): if self.workflow_run_id: - from api.models.workflow import WorkflowRun + from .workflow import WorkflowRun return db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first() return None diff --git a/api/models/task.py b/api/models/task.py index 2a1bfa124fad6..618d831d8ed4e 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timezone from celery import states @@ -15,8 +15,8 @@ class CeleryTask(db.Model): task_id = db.Column(db.String(155), unique=True) status = db.Column(db.String(50), default=states.PENDING) result = db.Column(db.PickleType, nullable=True) - date_done = db.Column(db.DateTime, default=datetime.utcnow, - onupdate=datetime.utcnow, nullable=True) + date_done = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), + onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None), nullable=True) traceback = db.Column(db.Text, nullable=True) name = db.Column(db.String(155), nullable=True) args = db.Column(db.LargeBinary, nullable=True) @@ -35,5 +35,5 @@ class CeleryTaskSet(db.Model): autoincrement=True, primary_key=True) taskset_id = db.Column(db.String(155), unique=True) result = db.Column(db.PickleType, nullable=True) - date_done = db.Column(db.DateTime, default=datetime.utcnow, + date_done = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), nullable=True) diff --git a/api/models/workflow.py b/api/models/workflow.py index 8db874b47100d..f65eba3637586 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -299,6 +299,10 @@ def message(self) -> Optional['Message']: Message.workflow_run_id == self.id ).first() + @property + def workflow(self): + return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() + class WorkflowNodeExecutionTriggeredFrom(Enum): """ diff --git a/api/requirements-dev.txt b/api/requirements-dev.txt new file mode 100644 index 0000000000000..2ac72f3797b64 --- /dev/null +++ b/api/requirements-dev.txt @@ -0,0 +1,4 @@ +coverage~=7.2.4 +pytest~=7.3.1 +pytest-mock~=3.11.1 +pytest-benchmark~=4.0.0 diff --git a/api/requirements.txt b/api/requirements.txt index 02d96a45a161c..39fdc1d94be4d 100644 --- a/api/requirements.txt +++ 
b/api/requirements.txt @@ -1,4 +1,3 @@ -coverage~=7.2.4 beautifulsoup4==4.12.2 flask~=3.0.1 Flask-SQLAlchemy~=3.0.5 @@ -10,15 +9,11 @@ flask-restful~=0.3.10 flask-cors~=4.0.0 gunicorn~=21.2.0 gevent~=23.9.1 -langchain==0.0.250 openai~=1.13.3 tiktoken~=0.6.0 psycopg2-binary~=2.9.6 pycryptodome==3.19.1 python-dotenv==1.0.0 -pytest~=7.3.1 -pytest-mock~=3.11.1 -pytest-benchmark~=4.0.0 Authlib==1.2.0 boto3==1.28.17 tenacity==8.2.2 @@ -34,20 +29,27 @@ redis[hiredis]~=5.0.3 openpyxl==3.1.2 chardet~=5.1.0 python-docx~=1.1.0 -pypdfium2==4.16.0 +pypdfium2 resend~=0.7.0 pyjwt~=2.8.0 anthropic~=0.23.1 newspaper3k==0.2.8 -google-api-python-client==2.90.0 wikipedia==1.4.0 readabilipy==0.2.0 +google-ai-generativelanguage==0.6.1 +google-api-core==2.18.0 +google-api-python-client==2.90.0 +google-auth==2.29.0 +google-auth-httplib2==0.2.0 +google-generativeai==0.5.0 google-search-results==2.4.2 +googleapis-common-protos==1.63.0 replicate~=0.22.0 websocket-client~=1.7.0 -dashscope[tokenizer]~=1.14.0 +dashscope[tokenizer]~=1.17.0 huggingface_hub~=0.16.4 -transformers~=4.31.0 +transformers~=4.35.0 +tokenizers~=0.15.0 pandas==1.5.3 xinference-client==0.9.4 safetensors==0.3.2 @@ -55,13 +57,12 @@ zhipuai==1.0.7 werkzeug~=3.0.1 pymilvus==2.3.0 qdrant-client==1.7.3 -cohere~=4.44 +cohere~=5.2.4 pyyaml~=6.0.1 numpy~=1.25.2 -unstructured[docx,pptx,msg,md,ppt]~=0.10.27 +unstructured[docx,pptx,msg,md,ppt,epub]~=0.10.27 bs4~=0.0.1 markdown~=3.5.1 -google-generativeai~=0.3.2 httpx[socks]~=0.24.1 matplotlib~=3.8.2 yfinance~=0.2.35 @@ -71,8 +72,11 @@ numexpr~=2.9.0 duckduckgo-search==5.2.2 arxiv==2.1.0 yarl~=1.9.4 -twilio==9.0.0 +twilio~=9.0.4 qrcode~=7.4.2 azure-storage-blob==12.9.0 azure-identity==1.15.0 -lxml==5.1.0 \ No newline at end of file +lxml==5.1.0 +xlrd~=2.0.1 +pydantic~=1.10.0 +pgvecto-rs==0.1.4 diff --git a/api/services/account_service.py b/api/services/account_service.py index 8b08a5f816ba6..1fe8da760c839 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -2,7 +2,7 @@ import logging import secrets import uuid -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from hashlib import sha256 from typing import Any, Optional @@ -59,8 +59,8 @@ def load_user(user_id: str) -> Account: available_ta.current = True db.session.commit() - if datetime.utcnow() - account.last_active_at > timedelta(minutes=10): - account.last_active_at = datetime.utcnow() + if datetime.now(timezone.utc).replace(tzinfo=None) - account.last_active_at > timedelta(minutes=10): + account.last_active_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() return account @@ -70,7 +70,7 @@ def load_user(user_id: str) -> Account: def get_account_jwt_token(account): payload = { "user_id": account.id, - "exp": datetime.utcnow() + timedelta(days=30), + "exp": datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(days=30), "iss": current_app.config['EDITION'], "sub": 'Console API Passport', } @@ -91,7 +91,7 @@ def authenticate(email: str, password: str) -> Account: if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value - account.initialized_at = datetime.utcnow() + account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() if account.password is None or not compare_password(password, account.password, account.password_salt): @@ -163,7 +163,7 @@ def link_account_integrate(provider: str, open_id: str, account: Account) -> Non # If it exists, update the record account_integrate.open_id = 
open_id account_integrate.encrypted_token = "" # todo - account_integrate.updated_at = datetime.utcnow() + account_integrate.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) else: # If it does not exist, create a new record account_integrate = AccountIntegrate(account_id=account.id, provider=provider, open_id=open_id, @@ -197,7 +197,7 @@ def update_account(account, **kwargs): @staticmethod def update_last_login(account: Account, request) -> None: """Update last login time and ip""" - account.last_login_at = datetime.utcnow() + account.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None) account.last_login_ip = get_remote_ip(request) db.session.add(account) db.session.commit() @@ -431,7 +431,7 @@ def register(cls, email, name, password: str = None, open_id: str = None, provid password=password ) account.status = AccountStatus.ACTIVE.value if not status else status.value - account.initialized_at = datetime.utcnow() + account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) if open_id is not None or provider is not None: AccountService.link_account_integrate(provider, open_id, account) diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index db4639d40b19d..8b00b28c4f75b 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -415,7 +415,7 @@ def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, raise NotFound("App annotation not found") annotation_setting.score_threshold = args['score_threshold'] annotation_setting.updated_user_id = current_user.id - annotation_setting.updated_at = datetime.datetime.utcnow() + annotation_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(annotation_setting) db.session.commit() diff --git a/api/services/app_service.py b/api/services/app_service.py index 673fa1e86b699..c2f7cbb02c424 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -1,6 +1,6 @@ import json import logging -from datetime import datetime +from datetime import datetime, timezone from typing import cast import yaml @@ -251,7 +251,7 @@ def update_app(self, app: App, args: dict) -> App: app.description = args.get('description', '') app.icon = args.get('icon') app.icon_background = args.get('icon_background') - app.updated_at = datetime.utcnow() + app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() return app @@ -264,7 +264,7 @@ def update_app_name(self, app: App, name: str) -> App: :return: App instance """ app.name = name - app.updated_at = datetime.utcnow() + app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() return app @@ -279,7 +279,7 @@ def update_app_icon(self, app: App, icon: str, icon_background: str) -> App: """ app.icon = icon app.icon_background = icon_background - app.updated_at = datetime.utcnow() + app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() return app @@ -295,7 +295,7 @@ def update_app_site_status(self, app: App, enable_site: bool) -> App: return app app.enable_site = enable_site - app.updated_at = datetime.utcnow() + app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() return app @@ -311,7 +311,7 @@ def update_app_api_status(self, app: App, enable_api: bool) -> App: return app app.enable_api = enable_api - app.updated_at = datetime.utcnow() + app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() return app 
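A pattern that repeats throughout the service and task hunks above is replacing the deprecated datetime.utcnow() with datetime.now(timezone.utc).replace(tzinfo=None): the timestamp is taken from an explicitly UTC-aware clock and then stripped back to naive so it still fits the existing naive DateTime columns. A minimal sketch of that idea, using a hypothetical naive_utc_now() helper that is not part of this patch:

    from datetime import datetime, timezone

    def naive_utc_now() -> datetime:
        # Same value the deprecated datetime.utcnow() returned, but derived from an
        # explicitly timezone-aware UTC clock and stripped back to naive so it can
        # still be stored in the existing naive DateTime columns.
        return datetime.now(timezone.utc).replace(tzinfo=None)

    # e.g. app.updated_at = naive_utc_now()

For SQLAlchemy column defaults (see the api/models/task.py hunk above), the patch wraps the expression in a lambda so it is evaluated per insert rather than once at import time.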
diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 1a0213799e619..5c2fb83b7249e 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,5 +1,8 @@ from typing import Optional, Union +from sqlalchemy import or_ + +from core.app.entities.app_invoke_entities import InvokeFrom from core.llm_generator.llm_generator import LLMGenerator from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination @@ -13,8 +16,9 @@ class ConversationService: @classmethod def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int, - include_ids: Optional[list] = None, exclude_ids: Optional[list] = None, - exclude_debug_conversation: bool = False) -> InfiniteScrollPagination: + invoke_from: InvokeFrom, + include_ids: Optional[list] = None, + exclude_ids: Optional[list] = None) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) @@ -24,6 +28,7 @@ def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, End Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'), Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), Conversation.from_account_id == (user.id if isinstance(user, Account) else None), + or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value) ) if include_ids is not None: @@ -32,9 +37,6 @@ def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, End if exclude_ids is not None: base_query = base_query.filter(~Conversation.id.in_(exclude_ids)) - if exclude_debug_conversation: - base_query = base_query.filter(Conversation.override_model_configs == None) - if last_id: last_conversation = base_query.filter( Conversation.id == last_id, diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 44a48af58b559..fe95a22cacfc6 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -415,7 +415,7 @@ def pause_document(document): # update document to be paused document.is_paused = True document.paused_by = current_user.id - document.paused_at = datetime.datetime.utcnow() + document.paused_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() @@ -739,7 +739,7 @@ def update_document_with_dataset_id(dataset: Dataset, document_data: dict, document.parsing_completed_at = None document.cleaning_completed_at = None document.splitting_completed_at = None - document.updated_at = datetime.datetime.utcnow() + document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) document.created_from = created_from document.doc_form = document_data['doc_form'] db.session.add(document) @@ -1046,73 +1046,11 @@ def create_segment(cls, args: dict, document: Document, dataset: Dataset): credentials=embedding_model.credentials, texts=[content] ) - max_position = db.session.query(func.max(DocumentSegment.position)).filter( - DocumentSegment.document_id == document.id - ).scalar() - segment_document = DocumentSegment( - tenant_id=current_user.current_tenant_id, - dataset_id=document.dataset_id, - document_id=document.id, - index_node_id=doc_id, - index_node_hash=segment_hash, - position=max_position + 1 if max_position else 1, - content=content, - word_count=len(content), - tokens=tokens, - status='completed', - 
indexing_at=datetime.datetime.utcnow(), - completed_at=datetime.datetime.utcnow(), - created_by=current_user.id - ) - if document.doc_form == 'qa_model': - segment_document.answer = args['answer'] - - db.session.add(segment_document) - db.session.commit() - - # save vector index - try: - VectorService.create_segments_vector([args['keywords']], [segment_document], dataset) - except Exception as e: - logging.exception("create segment index failed") - segment_document.enabled = False - segment_document.disabled_at = datetime.datetime.utcnow() - segment_document.status = 'error' - segment_document.error = str(e) - db.session.commit() - segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first() - return segment - - @classmethod - def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset): - embedding_model = None - if dataset.indexing_technique == 'high_quality': - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model - ) - max_position = db.session.query(func.max(DocumentSegment.position)).filter( - DocumentSegment.document_id == document.id - ).scalar() - pre_segment_data_list = [] - segment_data_list = [] - keywords_list = [] - for segment_item in segments: - content = segment_item['content'] - doc_id = str(uuid.uuid4()) - segment_hash = helper.generate_text_hash(content) - tokens = 0 - if dataset.indexing_technique == 'high_quality' and embedding_model: - # calc embedding use tokens - model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance) - tokens = model_type_instance.get_num_tokens( - model=embedding_model.model, - credentials=embedding_model.credentials, - texts=[content] - ) + lock_name = 'add_segment_lock_document_id_{}'.format(document.id) + with redis_client.lock(lock_name, timeout=600): + max_position = db.session.query(func.max(DocumentSegment.position)).filter( + DocumentSegment.document_id == document.id + ).scalar() segment_document = DocumentSegment( tenant_id=current_user.current_tenant_id, dataset_id=document.dataset_id, @@ -1124,30 +1062,96 @@ def multi_create_segment(cls, segments: list, document: Document, dataset: Datas word_count=len(content), tokens=tokens, status='completed', - indexing_at=datetime.datetime.utcnow(), - completed_at=datetime.datetime.utcnow(), + indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), created_by=current_user.id ) if document.doc_form == 'qa_model': - segment_document.answer = segment_item['answer'] - db.session.add(segment_document) - segment_data_list.append(segment_document) + segment_document.answer = args['answer'] - pre_segment_data_list.append(segment_document) - keywords_list.append(segment_item['keywords']) + db.session.add(segment_document) + db.session.commit() - try: # save vector index - VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset) - except Exception as e: - logging.exception("create segment index failed") - for segment_document in segment_data_list: + try: + VectorService.create_segments_vector([args['keywords']], [segment_document], dataset) + except Exception as e: + logging.exception("create segment index failed") segment_document.enabled = False - segment_document.disabled_at = 
datetime.datetime.utcnow() + segment_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) segment_document.status = 'error' segment_document.error = str(e) - db.session.commit() - return segment_data_list + db.session.commit() + segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first() + return segment + + @classmethod + def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset): + lock_name = 'multi_add_segment_lock_document_id_{}'.format(document.id) + with redis_client.lock(lock_name, timeout=600): + embedding_model = None + if dataset.indexing_technique == 'high_quality': + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model + ) + max_position = db.session.query(func.max(DocumentSegment.position)).filter( + DocumentSegment.document_id == document.id + ).scalar() + pre_segment_data_list = [] + segment_data_list = [] + keywords_list = [] + for segment_item in segments: + content = segment_item['content'] + doc_id = str(uuid.uuid4()) + segment_hash = helper.generate_text_hash(content) + tokens = 0 + if dataset.indexing_technique == 'high_quality' and embedding_model: + # calc embedding use tokens + model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance) + tokens = model_type_instance.get_num_tokens( + model=embedding_model.model, + credentials=embedding_model.credentials, + texts=[content] + ) + segment_document = DocumentSegment( + tenant_id=current_user.current_tenant_id, + dataset_id=document.dataset_id, + document_id=document.id, + index_node_id=doc_id, + index_node_hash=segment_hash, + position=max_position + 1 if max_position else 1, + content=content, + word_count=len(content), + tokens=tokens, + status='completed', + indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + created_by=current_user.id + ) + if document.doc_form == 'qa_model': + segment_document.answer = segment_item['answer'] + db.session.add(segment_document) + segment_data_list.append(segment_document) + + pre_segment_data_list.append(segment_document) + keywords_list.append(segment_item['keywords']) + + try: + # save vector index + VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset) + except Exception as e: + logging.exception("create segment index failed") + for segment_document in segment_data_list: + segment_document.enabled = False + segment_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment_document.status = 'error' + segment_document.error = str(e) + db.session.commit() + return segment_data_list @classmethod def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset): @@ -1204,10 +1208,10 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document segment.word_count = len(content) segment.tokens = tokens segment.status = 'completed' - segment.indexing_at = datetime.datetime.utcnow() - segment.completed_at = datetime.datetime.utcnow() + segment.indexing_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment.completed_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) segment.updated_by = 
current_user.id - segment.updated_at = datetime.datetime.utcnow() + segment.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) if document.doc_form == 'qa_model': segment.answer = args['answer'] db.session.add(segment) @@ -1217,7 +1221,7 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document except Exception as e: logging.exception("update segment index failed") segment.enabled = False - segment.disabled_at = datetime.datetime.utcnow() + segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) segment.status = 'error' segment.error = str(e) db.session.commit() diff --git a/api/services/file_service.py b/api/services/file_service.py index 53dd090236a77..f9df21f544b86 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -20,9 +20,10 @@ IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg'] IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) -ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv'] -UNSTRUSTURED_ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', - 'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml'] +ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'xls', 'docx', 'csv'] +UNSTRUSTURED_ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'xls', + 'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml', 'epub'] + PREVIEW_WORDS_LIMIT = 3000 @@ -80,7 +81,7 @@ def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bo mime_type=file.mimetype, created_by_role=('account' if isinstance(user, Account) else 'end_user'), created_by=user.id, - created_at=datetime.datetime.utcnow(), + created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), used=False, hash=hashlib.sha3_256(file_content).hexdigest() ) @@ -110,10 +111,10 @@ def upload_text(text: str, text_name: str) -> UploadFile: extension='txt', mime_type='text/plain', created_by=current_user.id, - created_at=datetime.datetime.utcnow(), + created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), used=True, used_by=current_user.id, - used_at=datetime.datetime.utcnow() + used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) ) db.session.add(upload_file) diff --git a/api/services/vector_service.py b/api/services/vector_service.py index d336162baeecb..232d2943256ca 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -34,7 +34,7 @@ def create_segments_vector(cls, keywords_list: Optional[list[list[str]]], keyword = Keyword(dataset) if keywords_list and len(keywords_list) > 0: - keyword.add_texts(documents, keyword_list=keywords_list) + keyword.add_texts(documents, keywords_list=keywords_list) else: keyword.add_texts(documents) diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index 06e3f6fd53ccc..cba048ccdbbbf 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -1,5 +1,6 @@ from typing import Optional, Union +from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.account import Account @@ -11,8 +12,8 @@ class WebConversationService: @classmethod def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - last_id: Optional[str], limit: int, 
pinned: Optional[bool] = None, - exclude_debug_conversation: bool = False) -> InfiniteScrollPagination: + last_id: Optional[str], limit: int, invoke_from: InvokeFrom, + pinned: Optional[bool] = None) -> InfiniteScrollPagination: include_ids = None exclude_ids = None if pinned is not None: @@ -32,9 +33,9 @@ def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, End user=user, last_id=last_id, limit=limit, + invoke_from=invoke_from, include_ids=include_ids, exclude_ids=exclude_ids, - exclude_debug_conversation=exclude_debug_conversation ) @classmethod diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index d15c383b0b0ad..e1cffdd1bda3f 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1,6 +1,6 @@ import json import time -from datetime import datetime +from datetime import datetime, timezone from typing import Optional from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager @@ -93,7 +93,7 @@ def sync_draft_workflow(self, app_model: App, workflow.graph = json.dumps(graph) workflow.features = json.dumps(features) workflow.updated_by = account.id - workflow.updated_at = datetime.utcnow() + workflow.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) # commit db session changes db.session.commit() @@ -123,7 +123,7 @@ def publish_workflow(self, app_model: App, tenant_id=app_model.tenant_id, app_id=app_model.id, type=draft_workflow.type, - version=str(datetime.utcnow()), + version=str(datetime.now(timezone.utc).replace(tzinfo=None)), graph=draft_workflow.graph, features=draft_workflow.features, created_by=account.id @@ -202,8 +202,8 @@ def run_draft_workflow_node(self, app_model: App, elapsed_time=time.perf_counter() - start_at, created_by_role=CreatedByRole.ACCOUNT.value, created_by=account.id, - created_at=datetime.utcnow(), - finished_at=datetime.utcnow() + created_at=datetime.now(timezone.utc).replace(tzinfo=None), + finished_at=datetime.now(timezone.utc).replace(tzinfo=None) ) db.session.add(workflow_node_execution) db.session.commit() @@ -230,8 +230,8 @@ def run_draft_workflow_node(self, app_model: App, elapsed_time=time.perf_counter() - start_at, created_by_role=CreatedByRole.ACCOUNT.value, created_by=account.id, - created_at=datetime.utcnow(), - finished_at=datetime.utcnow() + created_at=datetime.now(timezone.utc).replace(tzinfo=None), + finished_at=datetime.now(timezone.utc).replace(tzinfo=None) ) else: # create workflow node execution @@ -249,8 +249,8 @@ def run_draft_workflow_node(self, app_model: App, elapsed_time=time.perf_counter() - start_at, created_by_role=CreatedByRole.ACCOUNT.value, created_by=account.id, - created_at=datetime.utcnow(), - finished_at=datetime.utcnow() + created_at=datetime.now(timezone.utc).replace(tzinfo=None), + finished_at=datetime.now(timezone.utc).replace(tzinfo=None) ) db.session.add(workflow_node_execution) diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index a26ecf5526b9b..e0a1b219095b9 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -70,7 +70,7 @@ def add_document_to_index_task(dataset_document_id: str): except Exception as e: logging.exception("add document to index failed") dataset_document.enabled = False - dataset_document.disabled_at = datetime.datetime.utcnow() + dataset_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) dataset_document.status = 'error' dataset_document.error = 
str(e) db.session.commit() diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 666fa8692fb35..fda8b7a250f12 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -50,7 +50,7 @@ def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_ annotation_setting.score_threshold = score_threshold annotation_setting.collection_binding_id = dataset_collection_binding.id annotation_setting.updated_user_id = user_id - annotation_setting.updated_at = datetime.datetime.utcnow() + annotation_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(annotation_setting) else: new_app_annotation_setting = AppAnnotationSetting( diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 9e41c6aee31a0..d6dc970477cf8 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -85,9 +85,9 @@ def batch_create_segment_to_index_task(job_id: str, content: list, dataset_id: s word_count=len(content), tokens=tokens, created_by=user_id, - indexing_at=datetime.datetime.utcnow(), + indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), status='completed', - completed_at=datetime.datetime.utcnow() + completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) ) if dataset_document.doc_form == 'qa_model': segment_document.answer = segment['answer'] diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index b9737d7dddf5a..74d74ddf46279 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -46,16 +46,16 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, if documents is None or len(documents) == 0: logging.info(click.style('No documents found for dataset: {}'.format(dataset_id), fg='green')) - return + else: + logging.info(click.style('Cleaning documents for dataset: {}'.format(dataset_id), fg='green')) + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean(dataset, None) - index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, None) + for document in documents: + db.session.delete(document) - for document in documents: - db.session.delete(document) - - for segment in segments: - db.session.delete(segment) + for segment in segments: + db.session.delete(segment) db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete() db.session.query(DatasetQuery).filter(DatasetQuery.dataset_id == dataset_id).delete() diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index f33a1e91bf729..d31286e4cc4c4 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -38,7 +38,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] # update segment status to indexing update_params = { DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.utcnow() + DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) } DocumentSegment.query.filter_by(id=segment.id).update(update_params) db.session.commit() @@ -75,7 +75,7 @@ def create_segment_to_index_task(segment_id: str, keywords: 
Optional[list[str]] # update segment to completed update_params = { DocumentSegment.status: "completed", - DocumentSegment.completed_at: datetime.datetime.utcnow() + DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) } DocumentSegment.query.filter_by(id=segment.id).update(update_params) db.session.commit() @@ -85,7 +85,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] except Exception as e: logging.exception("create segment to index failed") segment.enabled = False - segment.disabled_at = datetime.datetime.utcnow() + segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) segment.status = 'error' segment.error = str(e) db.session.commit() diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index a646158dbd35a..c35c18799a603 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -67,7 +67,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): # check the page is updated if last_edited_time != page_edited_time: document.indexing_status = 'parsing' - document.processing_started_at = datetime.datetime.utcnow() + document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() # delete all document segment and index diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index b7762070503bf..43d1cc13f90da 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -47,7 +47,7 @@ def document_indexing_task(dataset_id: str, document_ids: list): if document: document.indexing_status = 'error' document.error = str(e) - document.stopped_at = datetime.datetime.utcnow() + document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() return @@ -62,7 +62,7 @@ def document_indexing_task(dataset_id: str, document_ids: list): if document: document.indexing_status = 'parsing' - document.processing_started_at = datetime.datetime.utcnow() + document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) documents.append(document) db.session.add(document) db.session.commit() diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index e59c549a65ff0..b27274be37ee2 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -33,7 +33,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): raise NotFound('Document not found') document.indexing_status = 'parsing' - document.processing_started_at = datetime.datetime.utcnow() + document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() # delete all document segment and index diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index a6254a822d19c..e37c06855d47d 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -69,7 +69,7 @@ def enable_segment_to_index_task(segment_id: str): except Exception as e: logging.exception("enable segment to index failed") segment.enabled = False - segment.disabled_at = datetime.datetime.utcnow() + segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) segment.status = 'error' segment.error = 
str(e) db.session.commit() diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index dd1baa79d4ec9..9cd04b47641a7 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -65,9 +65,12 @@ JINA_API_KEY= # Ollama Credentials OLLAMA_BASE_URL= +# Together API Key +TOGETHER_API_KEY= + # Mock Switch MOCK_SWITCH=false # CODE EXECUTION CONFIGURATION CODE_EXECUTION_ENDPOINT= -CODE_EXECUTINO_API_KEY= \ No newline at end of file +CODE_EXECUTION_API_KEY= \ No newline at end of file diff --git a/api/tests/integration_tests/model_runtime/__mock/google.py b/api/tests/integration_tests/model_runtime/__mock/google.py index 4ac4dfe1f04dd..cc4d8c6fbdaa0 100644 --- a/api/tests/integration_tests/model_runtime/__mock/google.py +++ b/api/tests/integration_tests/model_runtime/__mock/google.py @@ -10,6 +10,7 @@ from google.generativeai.client import _ClientManager, configure from google.generativeai.types import GenerateContentResponse from google.generativeai.types.generation_types import BaseGenerateContentResponse +from google.ai.generativelanguage_v1beta.types import content as gag_content current_api_key = '' @@ -29,7 +30,7 @@ def __iter__(self): }), chunks=[] - ) + ) else: yield GenerateContentResponse( done=False, @@ -43,6 +44,14 @@ def __iter__(self): class MockGoogleResponseCandidateClass(object): finish_reason = 'stop' + @property + def content(self) -> gag_content.Content: + return gag_content.Content( + parts=[ + gag_content.Part(text='it\'s google!') + ] + ) + class MockGoogleClass(object): @staticmethod def generate_content_sync() -> GenerateContentResponse: diff --git a/api/tests/unit_tests/core/rag/__init__.py b/api/tests/unit_tests/core/rag/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/api/tests/unit_tests/core/rag/datasource/__init__.py b/api/tests/unit_tests/core/rag/datasource/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/__init__.py b/api/tests/unit_tests/core/rag/datasource/vdb/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/__init__.py b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py new file mode 100644 index 0000000000000..73257dd3382c7 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py @@ -0,0 +1,24 @@ +import pytest +from pydantic.error_wrappers import ValidationError + +from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig + + +def test_default_value(): + valid_config = { + 'host': 'localhost', + 'port': 19530, + 'user': 'root', + 'password': 'Milvus' + } + + for key in valid_config: + config = valid_config.copy() + del config[key] + with pytest.raises(ValidationError) as e: + MilvusConfig(**config) + assert e.value.errors()[1]['msg'] == f'config MILVUS_{key.upper()} is required' + + config = MilvusConfig(**valid_config) + assert config.secure is False + assert config.database == 'default' diff --git a/dev/reformat b/dev/reformat index 864f9b4b02b12..ebee1efb40afd 100755 --- a/dev/reformat +++ b/dev/reformat @@ -10,3 +10,11 @@ fi # run ruff linter ruff check --fix ./api + +# env files linting relies on `dotenv-linter` in path +if ! 
command -v dotenv-linter &> /dev/null; then + echo "Installing dotenv-linter ..." + pip install dotenv-linter +fi + +dotenv-linter ./api/.env.example ./web/.env.example diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index af12e86fe60dc..9646b5493154d 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -29,33 +29,35 @@ services: - "6379:6379" # The Weaviate vector store. - weaviate: - image: semitechnologies/weaviate:1.19.0 - restart: always - volumes: - # Mount the Weaviate data directory to the container. - - ./volumes/weaviate:/var/lib/weaviate - environment: - # The Weaviate configurations - # You can refer to the [Weaviate](https://weaviate.io/developers/weaviate/config-refs/env-vars) documentation for more information. - QUERY_DEFAULTS_LIMIT: 25 - AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'false' - PERSISTENCE_DATA_PATH: '/var/lib/weaviate' - DEFAULT_VECTORIZER_MODULE: 'none' - CLUSTER_HOSTNAME: 'node1' - AUTHENTICATION_APIKEY_ENABLED: 'true' - AUTHENTICATION_APIKEY_ALLOWED_KEYS: 'WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih' - AUTHENTICATION_APIKEY_USERS: 'hello@dify.ai' - AUTHORIZATION_ADMINLIST_ENABLED: 'true' - AUTHORIZATION_ADMINLIST_USERS: 'hello@dify.ai' - ports: - - "8080:8080" + # weaviate: + # image: semitechnologies/weaviate:1.19.0 + # restart: always + # volumes: + # # Mount the Weaviate data directory to the container. + # - ./volumes/weaviate:/var/lib/weaviate + # environment: + # # The Weaviate configurations + # # You can refer to the [Weaviate](https://weaviate.io/developers/weaviate/config-refs/env-vars) documentation for more information. + # QUERY_DEFAULTS_LIMIT: 25 + # AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'false' + # PERSISTENCE_DATA_PATH: '/var/lib/weaviate' + # DEFAULT_VECTORIZER_MODULE: 'none' + # CLUSTER_HOSTNAME: 'node1' + # AUTHENTICATION_APIKEY_ENABLED: 'true' + # AUTHENTICATION_APIKEY_ALLOWED_KEYS: 'WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih' + # AUTHENTICATION_APIKEY_USERS: 'hello@dify.ai' + # AUTHORIZATION_ADMINLIST_ENABLED: 'true' + # AUTHORIZATION_ADMINLIST_USERS: 'hello@dify.ai' + # ports: + # - "8080:8080" # The DifySandbox sandbox: image: langgenius/dify-sandbox:latest restart: always cap_add: + # Why is sys_admin permission needed? + # https://docs.dify.ai/getting-started/install-self-hosted/install-faq#id-16.-why-is-sys_admin-permission-needed - SYS_ADMIN environment: # The DifySandbox configurations diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 43d48d4949268..dccdaee2cca35 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -2,7 +2,7 @@ version: '3' services: # API service api: - image: langgenius/dify-api:0.6.0 + image: langgenius/dify-api:0.6.3 restart: always environment: # Startup mode, 'api' starts the API server. @@ -150,7 +150,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.6.0 + image: langgenius/dify-api:0.6.3 restart: always environment: # Startup mode, 'worker' starts the Celery worker for processing the queue. @@ -223,6 +223,12 @@ services: # the api-key for resend (https://resend.com) RESEND_API_KEY: '' RESEND_API_URL: https://api.resend.com + # relyt configurations + RELYT_HOST: db + RELYT_PORT: 5432 + RELYT_USER: postgres + RELYT_PASSWORD: difyai123456 + RELYT_DATABASE: postgres depends_on: - db - redis @@ -232,7 +238,7 @@ services: # Frontend web application. 
web: - image: langgenius/dify-web:0.6.0 + image: langgenius/dify-web:0.6.3 restart: always environment: EDITION: SELF_HOSTED @@ -317,6 +323,8 @@ services: image: langgenius/dify-sandbox:latest restart: always cap_add: + # Why is sys_admin permission needed? + # https://docs.dify.ai/getting-started/install-self-hosted/install-faq#id-16.-why-is-sys_admin-permission-needed - SYS_ADMIN environment: # The DifySandbox configurations diff --git a/web/Dockerfile b/web/Dockerfile index 18cd3331ed44e..f2fc4af2f81e2 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -2,6 +2,9 @@ FROM node:20.11-alpine3.19 AS base LABEL maintainer="takatost@gmail.com" +# if you located in China, you can use aliyun mirror to speed up +# RUN sed -i 's/dl-cdn.alpinelinux.org/mirrors.aliyun.com/g' /etc/apk/repositories + RUN apk add --no-cache tzdata @@ -13,8 +16,10 @@ WORKDIR /app/web COPY package.json . COPY yarn.lock . -RUN yarn install --frozen-lockfile +# if you located in China, you can use taobao registry to speed up +# RUN yarn install --frozen-lockfile --registry https://registry.npmmirror.com/ +RUN yarn install --frozen-lockfile # build resources FROM base as builder diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index cb4e74cda7315..42bd05eec2e43 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -333,7 +333,7 @@ const AppInfo = ({ expand }: IAppInfoProps) => { )}/>
- {t('app.newApp.advanced')} + {showSwitchTip === 'chat' ? t('app.newApp.advanced') : t('app.types.workflow')} BETA
{t('app.newApp.advancedFor').toLocaleUpperCase()}
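The api/services/dataset_service.py hunks earlier in this patch wrap create_segment and multi_create_segment in a Redis lock so that reading max(position) and inserting the new segment happen atomically per document. A rough sketch of that locking shape, assuming a redis-py client like the service's redis_client (the helper name and query callback here are illustrative, not the patch's exact code):

    import redis

    redis_client = redis.Redis()  # assumed connection; the service injects its own client

    def allocate_next_position(document_id: str, query_max_position) -> int:
        # Serialize position allocation per document so two concurrent "add segment"
        # requests cannot both read the same max(position) and collide on insert.
        lock_name = f'add_segment_lock_document_id_{document_id}'
        with redis_client.lock(lock_name, timeout=600):
            max_position = query_max_position(document_id)  # e.g. SELECT max(position) ...
            return max_position + 1 if max_position else 1

The 600-second timeout mirrors the value used in the diff; it bounds how long a crashed worker can hold the lock.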
diff --git a/web/app/components/app/chat/thought/tool.tsx b/web/app/components/app/chat/thought/tool.tsx index 322df438b4885..324fcde7bcad9 100644 --- a/web/app/components/app/chat/thought/tool.tsx +++ b/web/app/components/app/chat/thought/tool.tsx @@ -19,7 +19,7 @@ type Props = { } const getIcon = (toolName: string, allToolIcons: Record) => { - if (toolName.startsWith('dataset-')) + if (toolName.startsWith('dataset_')) return const icon = allToolIcons[toolName] if (!icon) @@ -50,9 +50,9 @@ const Tool: FC = ({ }) => { const { t } = useTranslation() const { name, input, isFinished, output } = payload - const toolName = name.startsWith('dataset-') ? t('dataset.knowledge') : name + const toolName = name.startsWith('dataset_') ? t('dataset.knowledge') : name const [isShowDetail, setIsShowDetail] = useState(false) - const icon = getIcon(toolName, allToolIcons) as any + const icon = getIcon(name, allToolIcons) as any return (
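Further up, api/extensions/ext_celery.py adds broker_connection_retry_on_startup=True. Celery 5.3+ warns that the legacy broker_connection_retry setting will no longer govern startup connections from Celery 6 onward, so opting in explicitly keeps the retry behaviour and silences the warning. A minimal standalone sketch of the same configuration (the app name here is arbitrary):

    from celery import Celery

    celery_app = Celery('dify_api_sketch')
    celery_app.conf.update(
        # Keep retrying the broker connection at startup and silence the
        # Celery 5.3+ deprecation warning about broker_connection_retry.
        broker_connection_retry_on_startup=True,
    )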
diff --git a/web/app/components/app/configuration/config-var/config-modal/index.tsx b/web/app/components/app/configuration/config-var/config-modal/index.tsx index 671b5a0bd1990..58957acc4a078 100644 --- a/web/app/components/app/configuration/config-var/config-modal/index.tsx +++ b/web/app/components/app/configuration/config-var/config-modal/index.tsx @@ -17,7 +17,7 @@ import Switch from '@/app/components/base/switch' import { ChangeType, InputVarType } from '@/app/components/workflow/types' const TEXT_MAX_LENGTH = 256 -const PARAGRAPH_MAX_LENGTH = 1024 +const PARAGRAPH_MAX_LENGTH = 1032 * 32 export type IConfigModalProps = { isCreate?: boolean diff --git a/web/app/components/app/configuration/config/index.tsx b/web/app/components/app/configuration/config/index.tsx index 40a3980613089..b8bedba20ba1e 100644 --- a/web/app/components/app/configuration/config/index.tsx +++ b/web/app/components/app/configuration/config/index.tsx @@ -20,7 +20,8 @@ import ConfigContext from '@/context/debug-configuration' import ConfigPrompt from '@/app/components/app/configuration/config-prompt' import ConfigVar from '@/app/components/app/configuration/config-var' import { type CitationConfig, type ModelConfig, type ModerationConfig, type MoreLikeThisConfig, type PromptVariable, type SpeechToTextConfig, type SuggestedQuestionsAfterAnswerConfig, type TextToSpeechConfig } from '@/models/debug' -import { AppType, ModelModeType } from '@/types/app' +import type { AppType } from '@/types/app' +import { ModelModeType } from '@/types/app' import { useModalContext } from '@/context/modal-context' import ConfigParamModal from '@/app/components/app/configuration/toolbox/annotation/config-param-modal' import AnnotationFullModal from '@/app/components/billing/annotation-full/modal' @@ -60,7 +61,7 @@ const Config: FC = () => { moderationConfig, setModerationConfig, } = useContext(ConfigContext) - const isChatApp = mode === AppType.chat + const isChatApp = ['advanced-chat', 'agent-chat', 'chat'].includes(mode) const { data: speech2textDefaultModel } = useDefaultModel(ModelTypeEnum.speech2text) const { data: text2speechDefaultModel } = useDefaultModel(ModelTypeEnum.tts) const { setShowModerationSettingModal } = useModalContext() diff --git a/web/app/components/app/configuration/dataset-config/card-item/item.tsx b/web/app/components/app/configuration/dataset-config/card-item/item.tsx index ac221a81d4638..bc72b7d2998d8 100644 --- a/web/app/components/app/configuration/dataset-config/card-item/item.tsx +++ b/web/app/components/app/configuration/dataset-config/card-item/item.tsx @@ -66,7 +66,7 @@ const Item: FC = ({ ) } */}
-
+
setShowSettingsModal(true)} diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx index 6c6722a8ecf46..39a063182e123 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx @@ -121,7 +121,7 @@ const ChatItem: FC = ({ isResponding={isResponding} noChatInput noStopResponding - chatContainerclassName='p-4' + chatContainerClassName='p-4' chatFooterClassName='p-4 pb-0' suggestedQuestions={suggestedQuestions} onSend={doSend} diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/debug-item.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/debug-item.tsx index 43a782f610230..9f6da8a0988cf 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/debug-item.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/debug-item.tsx @@ -112,7 +112,7 @@ const DebugItem: FC = ({
{ - mode === 'chat' && currentProvider && currentModel && currentModel.status === ModelStatusEnum.active && ( + (mode === 'chat' || mode === 'agent-chat') && currentProvider && currentModel && currentModel.status === ModelStatusEnum.active && ( ) } diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx index 27348ecd5518e..892d0cfe8b329 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.tsx @@ -27,6 +27,7 @@ const DebugWithMultipleModel = () => { checkCanSend, } = useDebugWithMultipleModelContext() const { eventEmitter } = useEventEmitterContextContext() + const isChatMode = mode === 'chat' || mode === 'agent-chat' const handleSend = useCallback((message: string, files?: VisionFile[]) => { if (checkCanSend && !checkCanSend()) @@ -97,7 +98,7 @@ const DebugWithMultipleModel = () => { className={` grow mb-3 relative px-6 overflow-auto `} - style={{ height: mode === 'chat' ? 'calc(100% - 60px)' : '100%' }} + style={{ height: isChatMode ? 'calc(100% - 60px)' : '100%' }} > { multipleModelConfigs.map((modelConfig, index) => ( @@ -121,7 +122,7 @@ const DebugWithMultipleModel = () => { }
{ - mode === 'chat' && ( + isChatMode && (
{ 'hidden z-20 absolute right-[26px] top-[-158px] w-[376px] rounded-xl bg-white border-[0.5px] border-[rgba(0,0,0,0.05)] shadow-lg group-hover:block', )} > -
+
diff --git a/web/app/components/base/audio-btn/index.tsx b/web/app/components/base/audio-btn/index.tsx index 676b85414b4a9..8aee0b5f51343 100644 --- a/web/app/components/base/audio-btn/index.tsx +++ b/web/app/components/base/audio-btn/index.tsx @@ -1,11 +1,12 @@ 'use client' -import { useRef, useState } from 'react' +import { useEffect, useRef, useState } from 'react' import { t } from 'i18next' import { useParams, usePathname } from 'next/navigation' import s from './style.module.css' import Tooltip from '@/app/components/base/tooltip' import { randomString } from '@/utils' import { textToAudio } from '@/service/share' +import Loading from '@/app/components/base/loading' type AudioBtnProps = { value: string @@ -14,6 +15,8 @@ type AudioBtnProps = { isAudition?: boolean } +type AudioState = 'initial' | 'loading' | 'playing' | 'paused' | 'ended' + const AudioBtn = ({ value, voice, @@ -21,9 +24,8 @@ const AudioBtn = ({ isAudition, }: AudioBtnProps) => { const audioRef = useRef(null) - const [isPlaying, setIsPlaying] = useState(false) - const [isPause, setPause] = useState(false) - const [hasEnded, setHasEnded] = useState(false) + const [audioState, setAudioState] = useState('initial') + const selector = useRef(`play-tooltip-${randomString(4)}`) const params = useParams() const pathname = usePathname() @@ -34,9 +36,11 @@ const AudioBtn = ({ return '' } - const playAudio = async () => { + const loadAudio = async () => { const formData = new FormData() if (value !== '') { + setAudioState('loading') + formData.append('text', removeCodeBlocks(value)) formData.append('voice', removeCodeBlocks(voice)) @@ -59,67 +63,80 @@ const AudioBtn = ({ const blob_bytes = Buffer.from(audioResponse.data, 'latin1') const blob = new Blob([blob_bytes], { type: 'audio/wav' }) const audioUrl = URL.createObjectURL(blob) - const audio = new Audio(audioUrl) - audioRef.current = audio - audio.play().then(() => {}).catch(() => { - setIsPlaying(false) - URL.revokeObjectURL(audioUrl) - }) - audio.onended = () => { - setHasEnded(true) - setIsPlaying(false) - } + audioRef.current!.src = audioUrl } catch (error) { - setIsPlaying(false) + setAudioState('initial') console.error('Error playing audio:', error) } } } - const togglePlayPause = () => { + + const handleToggle = () => { + if (audioState === 'initial') + loadAudio() if (audioRef.current) { - if (isPlaying) { - if (!hasEnded) { - setPause(false) - audioRef.current.play() - } - if (!isPause) { - setPause(true) - audioRef.current.pause() - } + if (audioState === 'playing') { + audioRef.current.pause() + setAudioState('paused') } - else if (!isPlaying) { - if (isPause) { - setPause(false) - audioRef.current.play() - } - else { - setHasEnded(false) - playAudio().then() - } + else if (audioState === 'paused' || audioState === 'ended') { + audioRef.current.play() + setAudioState('playing') } - setIsPlaying(prevIsPlaying => !prevIsPlaying) - } - else { - setIsPlaying(true) - if (!isPlaying) - playAudio().then() } } + useEffect(() => { + const currentAudio = audioRef.current + const handleLoading = () => { + setAudioState('loading') + } + const handlePlay = () => { + currentAudio?.play() + setAudioState('playing') + } + const handleEnded = () => { + setAudioState('ended') + } + currentAudio?.addEventListener('progress', handleLoading) + currentAudio?.addEventListener('canplaythrough', handlePlay) + currentAudio?.addEventListener('ended', handleEnded) + return () => { + if (currentAudio) { + currentAudio.removeEventListener('progress', handleLoading) + 
currentAudio.removeEventListener('canplaythrough', handlePlay) + currentAudio.removeEventListener('ended', handleEnded) + URL.revokeObjectURL(currentAudio.src) + currentAudio.src = '' + } + } + }, []) + + const tooltipContent = { + initial: t('appApi.play'), + ended: t('appApi.play'), + paused: t('appApi.pause'), + playing: t('appApi.playing'), + loading: t('appApi.loading'), + }[audioState] + return ( -
+
-
-
-
+ onClick={handleToggle}> + {audioState === 'loading' &&
} + {audioState !== 'loading' &&
} +
+
) } diff --git a/web/app/components/base/audio-btn/style.module.css b/web/app/components/base/audio-btn/style.module.css index 7e3175aa139e2..b8a4da6b68d37 100644 --- a/web/app/components/base/audio-btn/style.module.css +++ b/web/app/components/base/audio-btn/style.module.css @@ -7,4 +7,4 @@ background-image: url(~@/app/components/develop/secret-key/assets/pause.svg); background-position: center; background-repeat: no-repeat; -} +} \ No newline at end of file diff --git a/web/app/components/base/auto-height-textarea/index.tsx b/web/app/components/base/auto-height-textarea/index.tsx index 760badebc277c..f1abbe3c578ed 100644 --- a/web/app/components/base/auto-height-textarea/index.tsx +++ b/web/app/components/base/auto-height-textarea/index.tsx @@ -1,5 +1,6 @@ import { forwardRef, useEffect, useRef } from 'react' import cn from 'classnames' +import { sleep } from '@/utils' type IProps = { placeholder?: string @@ -32,14 +33,13 @@ const AutoHeightTextarea = forwardRef( return false } - const focus = () => { + const focus = async () => { if (!doFocus()) { let hasFocus = false - const runId = setInterval(() => { - hasFocus = doFocus() - if (hasFocus) - clearInterval(runId) - }, 100) + await sleep(100) + hasFocus = doFocus() + if (!hasFocus) + focus() } } diff --git a/web/app/components/base/chat/chat/answer/index.tsx b/web/app/components/base/chat/chat/answer/index.tsx index c28093939894b..d338efc7e5d86 100644 --- a/web/app/components/base/chat/chat/answer/index.tsx +++ b/web/app/components/base/chat/chat/answer/index.tsx @@ -163,8 +163,8 @@ const Answer: FC = ({ } { - !!citation?.length && config?.retriever_resource?.enabled && !responding && ( - + !!citation?.length && !responding && ( + ) }
diff --git a/web/app/components/base/chat/chat/hooks.ts b/web/app/components/base/chat/chat/hooks.ts index 8aeaf1463abc1..4cdb6e8e38e77 100644 --- a/web/app/components/base/chat/chat/hooks.ts +++ b/web/app/components/base/chat/chat/hooks.ts @@ -229,7 +229,7 @@ export const useChat = ( // answer const responseItem: ChatItem = { - id: `${Date.now()}`, + id: placeholderAnswerId, content: '', agent_thoughts: [], message_files: [], diff --git a/web/app/components/base/chat/chat/index.tsx b/web/app/components/base/chat/chat/index.tsx index f0a2b345bf5f3..87332931f3d8a 100644 --- a/web/app/components/base/chat/chat/index.tsx +++ b/web/app/components/base/chat/chat/index.tsx @@ -11,6 +11,7 @@ import { } from 'react' import { useTranslation } from 'react-i18next' import { debounce } from 'lodash-es' +import classNames from 'classnames' import type { ChatConfig, ChatItem, @@ -36,7 +37,7 @@ export type ChatProps = { onStopResponding?: () => void noChatInput?: boolean onSend?: OnSend - chatContainerclassName?: string + chatContainerClassName?: string chatContainerInnerClassName?: string chatFooterClassName?: string chatFooterInnerClassName?: string @@ -60,7 +61,7 @@ const Chat: FC = ({ noStopResponding, onStopResponding, noChatInput, - chatContainerclassName, + chatContainerClassName, chatContainerInnerClassName, chatFooterClassName, chatFooterInnerClassName, @@ -171,7 +172,7 @@ const Chat: FC = ({
{chatNode}
= ({ + {/* */}
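Editor's note: the chat hooks change above assigns `placeholderAnswerId` to the streamed answer item, so incoming chunks update the existing placeholder bubble instead of appending a second one. A small sketch of why matching ids makes that work; `upsertById` is a hypothetical helper, not code from the patch.

```ts
// Sketch only: the stream and the placeholder must share an id to merge in place.
type ChatItem = { id: string; content: string; isAnswer?: boolean }

function upsertById(list: ChatItem[], item: ChatItem): ChatItem[] {
  const index = list.findIndex(existing => existing.id === item.id)
  if (index === -1)
    return [...list, item]
  // Same id: the placeholder entry is replaced by the streamed answer.
  return list.map((existing, i) => (i === index ? item : existing))
}
```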
diff --git a/web/app/components/base/prompt-editor/plugins/on-blur-or-focus-block.tsx b/web/app/components/base/prompt-editor/plugins/on-blur-or-focus-block.tsx index 173d29b1cfe8e..2e3adc15cf937 100644 --- a/web/app/components/base/prompt-editor/plugins/on-blur-or-focus-block.tsx +++ b/web/app/components/base/prompt-editor/plugins/on-blur-or-focus-block.tsx @@ -40,7 +40,7 @@ const OnBlurBlock: FC = ({ () => { ref.current = setTimeout(() => { editor.dispatchCommand(KEY_ESCAPE_COMMAND, new KeyboardEvent('keydown', { key: 'Escape' })) - }, 100) + }, 200) if (onBlur) onBlur() diff --git a/web/app/components/base/prompt-editor/plugins/variable-value-block/index.tsx b/web/app/components/base/prompt-editor/plugins/variable-value-block/index.tsx index 60c7c2cc0dc49..e93c0d7f99f1f 100644 --- a/web/app/components/base/prompt-editor/plugins/variable-value-block/index.tsx +++ b/web/app/components/base/prompt-editor/plugins/variable-value-block/index.tsx @@ -31,7 +31,7 @@ const VariableValueBlock = () => { if (matchArr === null) return null - const hashtagLength = matchArr[3].length + 4 + const hashtagLength = matchArr[0].length const startOffset = matchArr.index const endOffset = startOffset + hashtagLength return { diff --git a/web/app/components/base/prompt-editor/plugins/variable-value-block/utils.ts b/web/app/components/base/prompt-editor/plugins/variable-value-block/utils.ts index f1e5d7d88a2ff..4d59d41031e72 100644 --- a/web/app/components/base/prompt-editor/plugins/variable-value-block/utils.ts +++ b/web/app/components/base/prompt-editor/plugins/variable-value-block/utils.ts @@ -1,5 +1,5 @@ export function getHashtagRegexString(): string { - const hashtag = '(\{)(\{)([a-zA-Z_][a-zA-Z0-9_]{0,29})(\})(\})' + const hashtag = '\\{\\{[a-zA-Z_][a-zA-Z0-9_]{0,29}\\}\\}' return hashtag } diff --git a/web/app/components/base/tag-input/index.tsx b/web/app/components/base/tag-input/index.tsx index d974280ec0e83..dc6dfa98a1e44 100644 --- a/web/app/components/base/tag-input/index.tsx +++ b/web/app/components/base/tag-input/index.tsx @@ -56,7 +56,9 @@ const TagInput: FC = ({ } onChange([...items, valueTrimed]) - setValue('') + setTimeout(() => { + setValue('') + }) } } diff --git a/web/app/components/datasets/create/embedding-process/index.tsx b/web/app/components/datasets/create/embedding-process/index.tsx index f15091832e694..5fb34018c6e39 100644 --- a/web/app/components/datasets/create/embedding-process/index.tsx +++ b/web/app/components/datasets/create/embedding-process/index.tsx @@ -1,11 +1,10 @@ import type { FC } from 'react' -import React, { useCallback, useEffect, useMemo } from 'react' +import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react' import useSWR from 'swr' import { useRouter } from 'next/navigation' import { useTranslation } from 'react-i18next' import { omit } from 'lodash-es' import { ArrowRightIcon } from '@heroicons/react/24/solid' -import { useGetState } from 'ahooks' import cn from 'classnames' import s from './index.module.css' import { FieldInfo } from '@/app/components/datasets/documents/detail/metadata' @@ -22,6 +21,7 @@ import UpgradeBtn from '@/app/components/billing/upgrade-btn' import { useProviderContext } from '@/context/provider-context' import TooltipPlus from '@/app/components/base/tooltip-plus' import { AlertCircle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' +import { sleep } from '@/utils' type Props = { datasetId: string @@ -89,37 +89,48 @@ const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], index 
const getFirstDocument = documents[0] - const [indexingStatusBatchDetail, setIndexingStatusDetail, getIndexingStatusDetail] = useGetState([]) + const [indexingStatusBatchDetail, setIndexingStatusDetail] = useState([]) const fetchIndexingStatus = async () => { const status = await doFetchIndexingStatus({ datasetId, batchId }) setIndexingStatusDetail(status.data) + return status.data } - const [_, setRunId, getRunId] = useGetState>() - + const [isStopQuery, setIsStopQuery] = useState(false) + const isStopQueryRef = useRef(isStopQuery) + useEffect(() => { + isStopQueryRef.current = isStopQuery + }, [isStopQuery]) const stopQueryStatus = () => { - clearInterval(getRunId()) + setIsStopQuery(true) } - const startQueryStatus = () => { - const runId = setInterval(() => { - const indexingStatusBatchDetail = getIndexingStatusDetail() - const isCompleted = indexingStatusBatchDetail.every(indexingStatusDetail => ['completed', 'error'].includes(indexingStatusDetail.indexing_status)) + const startQueryStatus = async () => { + if (isStopQueryRef.current) + return + + try { + const indexingStatusBatchDetail = await fetchIndexingStatus() + const isCompleted = indexingStatusBatchDetail.every(indexingStatusDetail => ['completed', 'error', 'paused'].includes(indexingStatusDetail.indexing_status)) if (isCompleted) { stopQueryStatus() return } - fetchIndexingStatus() - }, 2500) - setRunId(runId) + await sleep(2500) + await startQueryStatus() + } + catch (e) { + await sleep(2500) + await startQueryStatus() + } } useEffect(() => { - fetchIndexingStatus() startQueryStatus() return () => { stopQueryStatus() } + // eslint-disable-next-line react-hooks/exhaustive-deps }, []) // get rule @@ -147,7 +158,7 @@ const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], index return indexingStatusBatchDetail.some(indexingStatusDetail => ['indexing', 'splitting', 'parsing', 'cleaning'].includes(indexingStatusDetail?.indexing_status || '')) }, [indexingStatusBatchDetail]) const isEmbeddingCompleted = useMemo(() => { - return indexingStatusBatchDetail.every(indexingStatusDetail => ['completed', 'error'].includes(indexingStatusDetail?.indexing_status || '')) + return indexingStatusBatchDetail.every(indexingStatusDetail => ['completed', 'error', 'paused'].includes(indexingStatusDetail?.indexing_status || '')) }, [indexingStatusBatchDetail]) const getSourceName = (id: string) => { @@ -221,11 +232,11 @@ const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], index indexingStatusDetail.indexing_status === 'completed' && s.success, )}> {isSourceEmbedding(indexingStatusDetail) && ( -
+
)}
{getSourceType(indexingStatusDetail.id) === DataSourceType.FILE && ( -
+
)} {getSourceType(indexingStatusDetail.id) === DataSourceType.NOTION && ( = ({ detail, stopPosition = 'top', datasetId: d const localDocumentId = docId ?? documentId const localIndexingTechnique = indexingType ?? indexingTechnique - const [indexingStatusDetail, setIndexingStatusDetail, getIndexingStatusDetail] = useGetState(null) + const [indexingStatusDetail, setIndexingStatusDetail] = useState(null) const fetchIndexingStatus = async () => { - try { - const status = await doFetchIndexingStatus({ datasetId: localDatasetId, documentId: localDocumentId }) - setIndexingStatusDetail(status) - // eslint-disable-next-line @typescript-eslint/no-use-before-define - startQueryStatus() - } - catch (err) { - // eslint-disable-next-line @typescript-eslint/no-use-before-define - stopQueryStatus() - notify({ type: 'error', message: `error: ${err}` }) - } + const status = await doFetchIndexingStatus({ datasetId: localDatasetId, documentId: localDocumentId }) + setIndexingStatusDetail(status) + return status } - const [runId, setRunId, getRunId] = useGetState(null) - + const [isStopQuery, setIsStopQuery] = useState(false) + const isStopQueryRef = useRef(isStopQuery) + useEffect(() => { + isStopQueryRef.current = isStopQuery + }, [isStopQuery]) const stopQueryStatus = () => { - clearInterval(getRunId()) + setIsStopQuery(true) } - const startQueryStatus = () => { - const runId = setInterval(() => { - const indexingStatusDetail = getIndexingStatusDetail() - if (indexingStatusDetail?.indexing_status === 'completed') { + const startQueryStatus = async () => { + if (isStopQueryRef.current) + return + + try { + const indexingStatusDetail = await fetchIndexingStatus() + if (['completed', 'error', 'paused'].includes(indexingStatusDetail?.indexing_status)) { stopQueryStatus() detailUpdate() return } - fetchIndexingStatus() - }, 2500) - setRunId(runId) + await sleep(2500) + await startQueryStatus() + } + catch (e) { + await sleep(2500) + await startQueryStatus() + } } useEffect(() => { - fetchIndexingStatus() + setIsStopQuery(false) + startQueryStatus() return () => { stopQueryStatus() } + // eslint-disable-next-line react-hooks/exhaustive-deps }, []) const { data: indexingEstimateDetail, error: indexingEstimateErr } = useSWR({ @@ -300,4 +303,4 @@ const EmbeddingDetail: FC = ({ detail, stopPosition = 'top', datasetId: d ) } -export default EmbeddingDetail +export default React.memo(EmbeddingDetail) diff --git a/web/app/components/header/account-setting/members-page/invited-modal/invitation-link.tsx b/web/app/components/header/account-setting/members-page/invited-modal/invitation-link.tsx index 9aedf0cb42a05..99acc7f40fd10 100644 --- a/web/app/components/header/account-setting/members-page/invited-modal/invitation-link.tsx +++ b/web/app/components/header/account-setting/members-page/invited-modal/invitation-link.tsx @@ -18,7 +18,7 @@ const InvitationLink = ({ const selector = useRef(`invite-link-${randomString(4)}`) const copyHandle = useCallback(() => { - copy(value.url) + copy(window.location.origin + value.url) setIsCopied(true) }, [value]) diff --git a/web/app/components/header/account-setting/model-provider-page/declarations.ts b/web/app/components/header/account-setting/model-provider-page/declarations.ts index da8c69b69d307..e7d799ff9b4d7 100644 --- a/web/app/components/header/account-setting/model-provider-page/declarations.ts +++ b/web/app/components/header/account-setting/model-provider-page/declarations.ts @@ -1,8 +1,8 @@ export type FormValue = Record export type TypeWithI18N = { - 'en-US': T - 'zh-Hans': 
T + en_US: T + zh_Hans: T [key: string]: T } @@ -67,16 +67,16 @@ export enum ModelStatusEnum { export const MODEL_STATUS_TEXT: { [k: string]: TypeWithI18N } = { 'no-configure': { - 'en-US': 'No Configure', - 'zh-Hans': '未配置凭据', + en_US: 'No Configure', + zh_Hans: '未配置凭据', }, 'quota-exceeded': { - 'en-US': 'Quota Exceeded', - 'zh-Hans': '额度不足', + en_US: 'Quota Exceeded', + zh_Hans: '额度不足', }, 'no-permission': { - 'en-US': 'No Permission', - 'zh-Hans': '无使用权限', + en_US: 'No Permission', + zh_Hans: '无使用权限', }, } diff --git a/web/app/components/header/account-setting/model-provider-page/hooks.ts b/web/app/components/header/account-setting/model-provider-page/hooks.ts index 58c7dc906d399..27f2b15582f8f 100644 --- a/web/app/components/header/account-setting/model-provider-page/hooks.ts +++ b/web/app/components/header/account-setting/model-provider-page/hooks.ts @@ -11,11 +11,11 @@ import type { DefaultModel, DefaultModelResponse, Model, - ModelTypeEnum, } from './declarations' import { ConfigurateMethodEnum, + ModelStatusEnum, } from './declarations' import I18n from '@/context/i18n' import { @@ -132,6 +132,7 @@ export const useCurrentProviderAndModel = (modelList: Model[], defaultModel?: De export const useTextGenerationCurrentProviderAndModelAndModelList = (defaultModel?: DefaultModel) => { const { textGenerationModelList } = useProviderContext() + const activeTextGenerationModelList = textGenerationModelList.filter(model => model.status === ModelStatusEnum.active) const { currentProvider, currentModel, @@ -141,6 +142,7 @@ export const useTextGenerationCurrentProviderAndModelAndModelList = (defaultMode currentProvider, currentModel, textGenerationModelList, + activeTextGenerationModelList, } } diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx index c9ce37a7e58dd..b1aba24eae87f 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx @@ -93,7 +93,7 @@ const ModelParameterModal: FC = ({ const { currentProvider, currentModel, - textGenerationModelList, + activeTextGenerationModelList, } = useTextGenerationCurrentProviderAndModelAndModelList( { provider, model: modelId }, ) @@ -114,7 +114,7 @@ const ModelParameterModal: FC = ({ } const handleChangeModel = ({ provider, model }: DefaultModel) => { - const targetProvider = textGenerationModelList.find(modelItem => modelItem.provider === provider) + const targetProvider = activeTextGenerationModelList.find(modelItem => modelItem.provider === provider) const targetModelItem = targetProvider?.models.find(modelItem => modelItem.model === model) setModel({ modelId: model, @@ -223,7 +223,7 @@ const ModelParameterModal: FC = ({
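Editor's note: besides normalizing the locale keys to `en_US` / `zh_Hans`, the provider hooks above expose an `activeTextGenerationModelList` so the model picker only offers providers whose status is active. A rough sketch of that filter under simplified stand-in types.

```ts
// Simplified stand-ins for the real declarations; only the filter matters here.
enum ModelStatusEnum {
  active = 'active',
  noConfigure = 'no-configure',
  quotaExceeded = 'quota-exceeded',
  noPermission = 'no-permission',
}

type ProviderModels = { provider: string; status: ModelStatusEnum; models: { model: string }[] }

const filterActiveModels = (list: ProviderModels[]): ProviderModels[] =>
  list.filter(item => item.status === ModelStatusEnum.active)
```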
diff --git a/web/app/components/share/text-generation/index.tsx b/web/app/components/share/text-generation/index.tsx index 58d9b42e4f699..1a88fcddd4f9a 100644 --- a/web/app/components/share/text-generation/index.tsx +++ b/web/app/components/share/text-generation/index.tsx @@ -174,7 +174,12 @@ const TextGeneration: FC = ({ promptConfig?.prompt_variables.forEach((v) => { res[v.name] = inputs[v.key] }) - res[t('share.generation.completionResult')] = batchCompletionResLatest[task.id] + let result = batchCompletionResLatest[task.id] + // task might return multiple fields, should marshal object to string + if (typeof batchCompletionResLatest[task.id] === 'object') + result = JSON.stringify(result) + + res[t('share.generation.completionResult')] = result return res }) const checkBatchInputs = (data: string[][]) => { diff --git a/web/app/components/share/text-generation/result/index.tsx b/web/app/components/share/text-generation/result/index.tsx index 6ccff5a2fdc6b..0a4a160be4b2c 100644 --- a/web/app/components/share/text-generation/result/index.tsx +++ b/web/app/components/share/text-generation/result/index.tsx @@ -17,6 +17,7 @@ import type { ModerationService } from '@/models/common' import { TransferMethod, type VisionFile, type VisionSettings } from '@/types/app' import { NodeRunningStatus, WorkflowRunningStatus } from '@/app/components/workflow/types' import type { WorkflowProcess } from '@/app/components/base/chat/types' +import { sleep } from '@/utils' export type IResultProps = { isWorkflow: boolean @@ -179,16 +180,16 @@ const Result: FC = ({ onShowRes() setRespondingTrue() - const startTime = Date.now() - let isTimeout = false - const runId = setInterval(() => { - if (Date.now() - startTime > 1000 * 60) { // 1min timeout - clearInterval(runId) + let isEnd = false + let isTimeout = false; + (async () => { + await sleep(1000 * 60) // 1min timeout + if (!isEnd) { setRespondingFalse() onCompleted(getCompletionRes(), taskId, false) isTimeout = true } - }, 1000) + })() if (isWorkflow) { sendWorkflowMessage( @@ -234,7 +235,7 @@ const Result: FC = ({ notify({ type: 'error', message: data.error }) setRespondingFalse() onCompleted(getCompletionRes(), taskId, false) - clearInterval(runId) + isEnd = true return } setWorkflowProccessData(produce(getWorkflowProccessData()!, (draft) => { @@ -249,7 +250,7 @@ const Result: FC = ({ setRespondingFalse() setMessageId(tempMessageId) onCompleted(getCompletionRes(), taskId, true) - clearInterval(runId) + isEnd = true }, }, isInstalledApp, @@ -269,7 +270,7 @@ const Result: FC = ({ setRespondingFalse() setMessageId(tempMessageId) onCompleted(getCompletionRes(), taskId, true) - clearInterval(runId) + isEnd = true }, onMessageReplace: (messageReplace) => { res = [messageReplace.answer] @@ -280,7 +281,7 @@ const Result: FC = ({ return setRespondingFalse() onCompleted(getCompletionRes(), taskId, false) - clearInterval(runId) + isEnd = true }, }, isInstalledApp, installedAppInfo?.id) } @@ -326,7 +327,7 @@ const Result: FC = ({ return (
{!isCallBatchAPI && ( - (isResponding && (!completionRes || !isWorkflow)) + (isResponding && !completionRes) ? (
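Editor's note: several hunks in this patch (auto-height-textarea, embedding-process, and the result panel above) drop `setInterval` in favour of an awaited sleep helper plus a completion flag. A minimal sketch of the one-minute timeout guard, assuming a sleep utility like the one imported from '@/utils'; `startTimeoutGuard` is an illustrative name.

```ts
// Sketch of the sleep-and-flag pattern that replaces the cleared interval.
const sleep = (ms: number) => new Promise<void>(resolve => setTimeout(resolve, ms))

function startTimeoutGuard(onTimeout: () => void, timeoutMs = 60 * 1000) {
  let isEnd = false;

  // Detached async task: fail the run after timeoutMs unless it already finished.
  (async () => {
    await sleep(timeoutMs)
    if (!isEnd)
      onTimeout()
  })()

  // Completion and error callbacks call this instead of clearInterval(runId).
  return () => { isEnd = true }
}
```

The embedding-process hunks apply the same helper to polling: an async function awaits `sleep(2500)` and calls itself until every document reaches a terminal indexing status or a stop flag is set.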
diff --git a/web/app/components/workflow/custom-connection-line.tsx b/web/app/components/workflow/custom-connection-line.tsx index a411370b05f3f..c187f16fe1899 100644 --- a/web/app/components/workflow/custom-connection-line.tsx +++ b/web/app/components/workflow/custom-connection-line.tsx @@ -2,19 +2,20 @@ import { memo } from 'react' import type { ConnectionLineComponentProps } from 'reactflow' import { Position, - getSimpleBezierPath, + getBezierPath, } from 'reactflow' const CustomConnectionLine = ({ fromX, fromY, toX, toY }: ConnectionLineComponentProps) => { const [ edgePath, - ] = getSimpleBezierPath({ + ] = getBezierPath({ sourceX: fromX, sourceY: fromY, sourcePosition: Position.Right, targetX: toX, targetY: toY, targetPosition: Position.Left, + curvature: 0.16, }) return ( diff --git a/web/app/components/workflow/custom-edge.tsx b/web/app/components/workflow/custom-edge.tsx index 3575389c8e1c6..072aad8d54257 100644 --- a/web/app/components/workflow/custom-edge.tsx +++ b/web/app/components/workflow/custom-edge.tsx @@ -9,7 +9,7 @@ import { BaseEdge, EdgeLabelRenderer, Position, - getSimpleBezierPath, + getBezierPath, } from 'reactflow' import { useNodesExtraData, @@ -38,13 +38,14 @@ const CustomEdge = ({ edgePath, labelX, labelY, - ] = getSimpleBezierPath({ + ] = getBezierPath({ sourceX: sourceX - 8, sourceY, sourcePosition: Position.Right, targetX: targetX + 8, targetY, targetPosition: Position.Left, + curvature: 0.16, }) const [open, setOpen] = useState(false) const { handleNodeAdd } = useNodesInteractions() diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index 99c0cd224bfd6..ea9af3e9aa35e 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -43,7 +43,10 @@ export const useNodesInteractions = () => { const workflowStore = useWorkflowStore() const nodesExtraData = useNodesExtraData() const { handleSyncWorkflowDraft } = useNodesSyncDraft() - const { getAfterNodesInSameBranch } = useWorkflow() + const { + getAfterNodesInSameBranch, + getTreeLeafNodes, + } = useWorkflow() const { getNodesReadOnly } = useNodesReadOnly() const dragNodeStartPosition = useRef({ x: 0, y: 0 } as { x: number; y: number }) const connectingNodeRef = useRef<{ nodeId: string; handleType: HandleType } | null>(null) @@ -301,6 +304,8 @@ export const useNodesInteractions = () => { target, targetHandle, }) => { + if (source === target) + return if (getNodesReadOnly()) return @@ -311,6 +316,13 @@ export const useNodesInteractions = () => { setEdges, } = store.getState() const nodes = getNodes() + const targetNode = nodes.find(node => node.id === target!) + if (targetNode && targetNode?.data.type === BlockEnum.VariableAssigner) { + const treeNodes = getTreeLeafNodes(target!) 
+ + if (!treeNodes.find(treeNode => treeNode.id === source)) + return + } const needDeleteEdges = edges.filter((edge) => { if (edge.source === source) { if (edge.sourceHandle) @@ -366,7 +378,7 @@ export const useNodesInteractions = () => { }) setEdges(newEdges) handleSyncWorkflowDraft() - }, [store, handleSyncWorkflowDraft, getNodesReadOnly]) + }, [store, handleSyncWorkflowDraft, getNodesReadOnly, getTreeLeafNodes]) const handleNodeConnectStart = useCallback((_, { nodeId, handleType }) => { if (nodeId && handleType) { @@ -703,6 +715,123 @@ export const useNodesInteractions = () => { handleSyncWorkflowDraft() }, [store, handleSyncWorkflowDraft, getNodesReadOnly, t]) + const handleNodeCopySelected = useCallback((): undefined | Node[] => { + if (getNodesReadOnly()) + return + + const { + setClipboardElements, + shortcutsDisabled, + } = workflowStore.getState() + + if (shortcutsDisabled) + return + + const { + getNodes, + } = store.getState() + + const nodes = getNodes() + const nodesToCopy = nodes.filter(node => node.data.selected) + + setClipboardElements(nodesToCopy) + + return nodesToCopy + }, [getNodesReadOnly, store, workflowStore]) + + const handleNodePaste = useCallback((): undefined | Node[] => { + if (getNodesReadOnly()) + return + + const { + clipboardElements, + shortcutsDisabled, + } = workflowStore.getState() + + if (shortcutsDisabled) + return + + const { + getNodes, + setNodes, + } = store.getState() + + const nodesToPaste: Node[] = [] + const nodes = getNodes() + + for (const nodeToPaste of clipboardElements) { + const nodeType = nodeToPaste.data.type + const nodesWithSameType = nodes.filter(node => node.data.type === nodeType) + + const newNode = generateNewNode({ + data: { + ...NODES_INITIAL_DATA[nodeType], + ...nodeToPaste.data, + _connectedSourceHandleIds: [], + _connectedTargetHandleIds: [], + title: nodesWithSameType.length > 0 ? 
`${t(`workflow.blocks.${nodeType}`)} ${nodesWithSameType.length + 1}` : t(`workflow.blocks.${nodeType}`), + selected: true, + }, + position: { + x: nodeToPaste.position.x + 10, + y: nodeToPaste.position.y + 10, + }, + }) + nodesToPaste.push(newNode) + } + + setNodes([...nodes.map((n: Node) => ({ ...n, selected: false, data: { ...n.data, selected: false } })), ...nodesToPaste]) + + handleSyncWorkflowDraft() + + return nodesToPaste + }, [getNodesReadOnly, handleSyncWorkflowDraft, store, t, workflowStore]) + + const handleNodeDuplicateSelected = useCallback(() => { + if (getNodesReadOnly()) + return + + handleNodeCopySelected() + handleNodePaste() + }, [getNodesReadOnly, handleNodeCopySelected, handleNodePaste]) + + const handleNodeCut = useCallback(() => { + if (getNodesReadOnly()) + return + + const nodesToCut = handleNodeCopySelected() + if (!nodesToCut) + return + + for (const node of nodesToCut) + handleNodeDelete(node.id) + }, [getNodesReadOnly, handleNodeCopySelected, handleNodeDelete]) + + const handleNodeDeleteSelected = useCallback(() => { + if (getNodesReadOnly()) + return + + const { + shortcutsDisabled, + } = workflowStore.getState() + + if (shortcutsDisabled) + return + + const { + getNodes, + } = store.getState() + + const nodes = getNodes() + const nodesToDelete = nodes.filter(node => node.data.selected) + + if (!nodesToDelete) + return + + for (const node of nodesToDelete) + handleNodeDelete(node.id) + }, [getNodesReadOnly, handleNodeDelete, store, workflowStore]) + return { handleNodeDragStart, handleNodeDrag, @@ -717,5 +846,10 @@ export const useNodesInteractions = () => { handleNodeDelete, handleNodeChange, handleNodeAdd, + handleNodeDuplicateSelected, + handleNodeCopySelected, + handleNodeCut, + handleNodeDeleteSelected, + handleNodePaste, } } diff --git a/web/app/components/workflow/hooks/use-workflow-run.ts b/web/app/components/workflow/hooks/use-workflow-run.ts index 67fce538e9530..6f6b4e4edaba2 100644 --- a/web/app/components/workflow/hooks/use-workflow-run.ts +++ b/web/app/components/workflow/hooks/use-workflow-run.ts @@ -176,6 +176,8 @@ export const useWorkflowRun = () => { const { getNodes, setNodes, + edges, + setEdges, } = store.getState() setWorkflowRunningData(produce(workflowRunningData!, (draft) => { draft.task_id = task_id @@ -192,6 +194,15 @@ export const useWorkflowRun = () => { }) }) setNodes(newNodes) + const newEdges = produce(edges, (draft) => { + draft.forEach((edge) => { + edge.data = { + ...edge.data, + _runned: false, + } + }) + }) + setEdges(newEdges) if (onWorkflowStarted) onWorkflowStarted(params) diff --git a/web/app/components/workflow/hooks/use-workflow.ts b/web/app/components/workflow/hooks/use-workflow.ts index 37feb1ae8d542..6864410c8abe7 100644 --- a/web/app/components/workflow/hooks/use-workflow.ts +++ b/web/app/components/workflow/hooks/use-workflow.ts @@ -160,8 +160,10 @@ export const useWorkflow = () => { if (incomers.length) { incomers.forEach((node) => { - callback(node) - traverse(node, callback) + if (!list.find(n => node.id === n.id)) { + callback(node) + traverse(node, callback) + } }) } } @@ -272,7 +274,10 @@ export const useWorkflow = () => { }, [isVarUsedInNodes]) const isValidConnection = useCallback(({ source, target }: Connection) => { - const { getNodes } = store.getState() + const { + edges, + getNodes, + } = store.getState() const nodes = getNodes() const sourceNode: Node = nodes.find(node => node.id === source)! const targetNode: Node = nodes.find(node => node.id === target)! 
@@ -287,7 +292,21 @@ export const useWorkflow = () => { return false } - return true + const hasCycle = (node: Node, visited = new Set()) => { + if (visited.has(node.id)) + return false + + visited.add(node.id) + + for (const outgoer of getOutgoers(node, nodes, edges)) { + if (outgoer.id === source) + return true + if (hasCycle(outgoer, visited)) + return true + } + } + + return !hasCycle(targetNode) }, [store, nodesExtraData]) const formatTimeFromNow = useCallback((time: number) => { @@ -312,6 +331,16 @@ export const useWorkflow = () => { return nodes.find(node => node.id === nodeId) || nodes.find(node => node.data.type === BlockEnum.Start) }, [store]) + const enableShortcuts = useCallback(() => { + const { setShortcutsDisabled } = workflowStore.getState() + setShortcutsDisabled(false) + }, [workflowStore]) + + const disableShortcuts = useCallback(() => { + const { setShortcutsDisabled } = workflowStore.getState() + setShortcutsDisabled(true) + }, [workflowStore]) + return { handleLayout, getTreeLeafNodes, @@ -326,6 +355,8 @@ export const useWorkflow = () => { renderTreeFromRecord, getNode, getBeforeNodeById, + enableShortcuts, + disableShortcuts, } } diff --git a/web/app/components/workflow/index.tsx b/web/app/components/workflow/index.tsx index 3bcd1cccd5433..fdd6d73fad8cd 100644 --- a/web/app/components/workflow/index.tsx +++ b/web/app/components/workflow/index.tsx @@ -113,6 +113,11 @@ const Workflow: FC = memo(({ handleNodeConnect, handleNodeConnectStart, handleNodeConnectEnd, + handleNodeDuplicateSelected, + handleNodeCopySelected, + handleNodeCut, + handleNodeDeleteSelected, + handleNodePaste, } = useNodesInteractions() const { handleEdgeEnter, @@ -120,7 +125,11 @@ const Workflow: FC = memo(({ handleEdgeDelete, handleEdgesChange, } = useEdgesInteractions() - const { isValidConnection } = useWorkflow() + const { + isValidConnection, + enableShortcuts, + disableShortcuts, + } = useWorkflow() useOnViewportChange({ onEnd: () => { @@ -128,7 +137,12 @@ const Workflow: FC = memo(({ }, }) - useKeyPress('Backspace', handleEdgeDelete) + useKeyPress(['delete'], handleEdgeDelete) + useKeyPress(['delete', 'backspace'], handleNodeDeleteSelected) + useKeyPress(['ctrl.c', 'meta.c'], handleNodeCopySelected) + useKeyPress(['ctrl.x', 'meta.x'], handleNodeCut) + useKeyPress(['ctrl.v', 'meta.v'], handleNodePaste) + useKeyPress(['ctrl.alt.d', 'meta.shift.d'], handleNodeDuplicateSelected) return (
= memo(({ edgeTypes={edgeTypes} nodes={nodes} edges={edges} + onPointerDown={enableShortcuts} + onMouseLeave={disableShortcuts} onNodeDragStart={handleNodeDragStart} onNodeDrag={handleNodeDrag} onNodeDragStop={handleNodeDragStop} diff --git a/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx b/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx index be392098dfd1e..fd7c02eabad66 100644 --- a/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx +++ b/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx @@ -12,6 +12,7 @@ import Split from '@/app/components/workflow/nodes/_base/components/split' import { InputVarType, NodeRunningStatus } from '@/app/components/workflow/types' import ResultPanel from '@/app/components/workflow/run/result-panel' import Toast from '@/app/components/base/toast' +import { TransferMethod } from '@/types/app' const i18nPrefix = 'workflow.singleRun' @@ -51,7 +52,18 @@ const BeforeRunForm: FC = ({ const isFinished = runningStatus === NodeRunningStatus.Succeeded || runningStatus === NodeRunningStatus.Failed const isRunning = runningStatus === NodeRunningStatus.Running + const isFileLoaded = (() => { + // system files + const filesForm = forms.find(item => !!item.values['#files#']) + if (!filesForm) + return true + const files = filesForm.values['#files#'] as any + if (files?.some((item: any) => item.transfer_method === TransferMethod.local_file && !item.upload_file_id)) + return false + + return true + })() const handleRun = useCallback(() => { let errMsg = '' forms.forEach((form) => { @@ -129,7 +141,7 @@ const BeforeRunForm: FC = ({
)} - diff --git a/web/app/components/workflow/nodes/_base/components/editor/base.tsx b/web/app/components/workflow/nodes/_base/components/editor/base.tsx index 779746d68ccb7..b76544ffe623e 100644 --- a/web/app/components/workflow/nodes/_base/components/editor/base.tsx +++ b/web/app/components/workflow/nodes/_base/components/editor/base.tsx @@ -49,7 +49,7 @@ const Base: FC = ({
{title}
-
+
e.stopPropagation()}> {headerRight} {!isCopied ? ( diff --git a/web/app/components/workflow/nodes/_base/components/variable/utils.ts b/web/app/components/workflow/nodes/_base/components/variable/utils.ts index fb12e366ff12d..969eeea976f01 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/utils.ts +++ b/web/app/components/workflow/nodes/_base/components/variable/utils.ts @@ -233,6 +233,16 @@ const matchNotSystemVars = (prompts: string[]) => { return uniqVars } +const replaceOldVarInText = (text: string, oldVar: ValueSelector, newVar: ValueSelector) => { + if (!text || typeof text !== 'string') + return text + + if (!newVar || newVar.length === 0) + return text + + return text.replaceAll(`{{#${oldVar.join('.')}#}}`, `{{#${newVar.join('.')}#}}`) +} + export const getNodeUsedVars = (node: Node): ValueSelector[] => { const { data } = node const { type } = data @@ -349,14 +359,21 @@ export const updateNodeVars = (oldNode: Node, oldVarSelector: ValueSelector, new } case BlockEnum.LLM: { const payload = data as LLMNodeType - // TODO: update in inputs - // if (payload.variables) { - // payload.variables = payload.variables.map((v) => { - // if (v.value_selector.join('.') === oldVarSelector.join('.')) - // v.value_selector = newVarSelector - // return v - // }) - // } + const isChatModel = payload.model?.mode === 'chat' + if (isChatModel) { + payload.prompt_template = (payload.prompt_template as PromptItem[]).map((prompt) => { + return { + ...prompt, + text: replaceOldVarInText(prompt.text, oldVarSelector, newVarSelector), + } + }) + } + else { + payload.prompt_template = { + ...payload.prompt_template, + text: replaceOldVarInText((payload.prompt_template as PromptItem).text, oldVarSelector, newVarSelector), + } + } if (payload.context?.variable_selector?.join('.') === oldVarSelector.join('.')) payload.context.variable_selector = newVarSelector @@ -408,30 +425,35 @@ export const updateNodeVars = (oldNode: Node, oldVarSelector: ValueSelector, new break } case BlockEnum.HttpRequest: { - // TODO: update in inputs - // const payload = data as HttpNodeType - // if (payload.variables) { - // payload.variables = payload.variables.map((v) => { - // if (v.value_selector.join('.') === oldVarSelector.join('.')) - // v.value_selector = newVarSelector - // return v - // }) - // } + const payload = data as HttpNodeType + payload.url = replaceOldVarInText(payload.url, oldVarSelector, newVarSelector) + payload.headers = replaceOldVarInText(payload.headers, oldVarSelector, newVarSelector) + payload.params = replaceOldVarInText(payload.params, oldVarSelector, newVarSelector) + payload.body.data = replaceOldVarInText(payload.body.data, oldVarSelector, newVarSelector) break } case BlockEnum.Tool: { - // TODO: update in inputs - // const payload = data as ToolNodeType - // if (payload.tool_parameters) { - // payload.tool_parameters = payload.tool_parameters.map((v) => { - // if (v.type === VarKindType.static) - // return v - - // if (v.value_selector?.join('.') === oldVarSelector.join('.')) - // v.value_selector = newVarSelector - // return v - // }) - // } + const payload = data as ToolNodeType + const hasShouldRenameVar = Object.keys(payload.tool_parameters)?.filter(key => payload.tool_parameters[key].type !== ToolVarType.constant) + if (hasShouldRenameVar) { + Object.keys(payload.tool_parameters).forEach((key) => { + const value = payload.tool_parameters[key] + const { type } = value + if (type === ToolVarType.variable) { + payload.tool_parameters[key] = { + ...value, + value: newVarSelector, + 
} + } + + if (type === ToolVarType.mixed) { + payload.tool_parameters[key] = { + ...value, + value: replaceOldVarInText(payload.tool_parameters[key].value as string, oldVarSelector, newVarSelector), + } + } + }) + } break } case BlockEnum.VariableAssigner: { diff --git a/web/app/components/workflow/nodes/_base/components/variable/var-reference-vars.tsx b/web/app/components/workflow/nodes/_base/components/variable/var-reference-vars.tsx index 93817ab1b2362..34fff7e6191a1 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/var-reference-vars.tsx +++ b/web/app/components/workflow/nodes/_base/components/variable/var-reference-vars.tsx @@ -208,18 +208,24 @@ const VarReferenceVars: FC = ({ const filteredVars = vars.filter((v) => { const children = v.vars.filter(v => checkKeys([v.variable], false).isValid || v.variable.startsWith('sys.')) return children.length > 0 - }).filter((v) => { + }).filter((node) => { if (!searchText) - return v - const children = v.vars.filter(v => v.variable.toLowerCase().includes(searchText.toLowerCase())) + return node + const children = node.vars.filter((v) => { + const searchTextLower = searchText.toLowerCase() + return v.variable.toLowerCase().includes(searchTextLower) || node.title.toLowerCase().includes(searchTextLower) + }) return children.length > 0 - }).map((v) => { - let vars = v.vars.filter(v => checkKeys([v.variable], false).isValid || v.variable.startsWith('sys.')) - if (searchText) - vars = vars.filter(v => v.variable.toLowerCase().includes(searchText.toLowerCase())) + }).map((node) => { + let vars = node.vars.filter(v => checkKeys([v.variable], false).isValid || v.variable.startsWith('sys.')) + if (searchText) { + const searchTextLower = searchText.toLowerCase() + if (!node.title.toLowerCase().includes(searchTextLower)) + vars = vars.filter(v => v.variable.toLowerCase().includes(searchText.toLowerCase())) + } return { - ...v, + ...node, vars, } }) @@ -266,7 +272,7 @@ const VarReferenceVars: FC = ({ } {filteredVars.length > 0 - ?
+ ?
{ filteredVars.map((item, i) => ( diff --git a/web/app/components/workflow/nodes/end/node.tsx b/web/app/components/workflow/nodes/end/node.tsx index c2e1b59608674..e083171a55256 100644 --- a/web/app/components/workflow/nodes/end/node.tsx +++ b/web/app/components/workflow/nodes/end/node.tsx @@ -79,8 +79,7 @@ const Node: FC> = ({
-
{getVarType(node?.id || '', value_selector)}
- +
{getVarType(node?.id || '', value_selector)}
) diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/node.tsx b/web/app/components/workflow/nodes/knowledge-retrieval/node.tsx index b32229818f740..3b8357873bdd4 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/node.tsx +++ b/web/app/components/workflow/nodes/knowledge-retrieval/node.tsx @@ -1,4 +1,4 @@ -import { type FC, useEffect, useState } from 'react' +import { type FC, useEffect, useRef, useState } from 'react' import React from 'react' import type { KnowledgeRetrievalNodeType } from './types' import { Folder } from '@/app/components/base/icons/src/vender/solid/files' @@ -10,10 +10,17 @@ const Node: FC> = ({ data, }) => { const [selectedDatasets, setSelectedDatasets] = useState([]) + const updateTime = useRef(0) useEffect(() => { (async () => { + updateTime.current = updateTime.current + 1 + const currUpdateTime = updateTime.current + if (data.dataset_ids?.length > 0) { const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: data.dataset_ids } }) + // avoid old data overwrite new data + if (currUpdateTime < updateTime.current) + return setSelectedDatasets(dataSetsWithDetail) } else { @@ -33,7 +40,7 @@ const Node: FC> = ({
-
+
{name}
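Editor's note: the knowledge-retrieval node hunk above guards its async dataset fetch with an incrementing ref so a slow, older response cannot overwrite the result of a newer one. A sketch of that "latest request wins" pattern with hypothetical names; the node itself uses a bare `useRef` counter.

```ts
// Sketch: only the most recently issued request is allowed to apply its result.
function createLatestOnly<T>(fetcher: (ids: string[]) => Promise<T>) {
  let requestSeq = 0

  return async (ids: string[], apply: (data: T) => void) => {
    const current = ++requestSeq
    const data = await fetcher(ids)
    // A newer request started after this one; drop the stale response.
    if (current !== requestSeq)
      return
    apply(data)
  }
}
```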
diff --git a/web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx b/web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx new file mode 100644 index 0000000000000..a6d952545394f --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx @@ -0,0 +1,107 @@ +'use client' +import type { FC } from 'react' +import React, { useEffect, useState } from 'react' +import { uniqueId } from 'lodash-es' +import { useTranslation } from 'react-i18next' +import type { PromptItem } from '../../../types' +import Editor from '@/app/components/workflow/nodes/_base/components/prompt/editor' +import TypeSelector from '@/app/components/workflow/nodes/_base/components/selector' +import TooltipPlus from '@/app/components/base/tooltip-plus' +import { HelpCircle } from '@/app/components/base/icons/src/vender/line/general' +import { PromptRole } from '@/models/debug' + +const i18nPrefix = 'workflow.nodes.llm' + +type Props = { + readOnly: boolean + id: string + canRemove: boolean + isChatModel: boolean + isChatApp: boolean + payload: PromptItem + handleChatModeMessageRoleChange: (role: PromptRole) => void + onPromptChange: (p: string) => void + onRemove: () => void + isShowContext: boolean + hasSetBlockStatus: { + context: boolean + history: boolean + query: boolean + } + availableVars: any + availableNodes: any +} + +const roleOptions = [ + { + label: 'system', + value: PromptRole.system, + }, + { + label: 'user', + value: PromptRole.user, + }, + { + label: 'assistant', + value: PromptRole.assistant, + }, +] + +const ConfigPromptItem: FC = ({ + readOnly, + id, + canRemove, + handleChatModeMessageRoleChange, + isChatModel, + isChatApp, + payload, + onPromptChange, + onRemove, + isShowContext, + hasSetBlockStatus, + availableVars, + availableNodes, +}) => { + const { t } = useTranslation() + const [instanceId, setInstanceId] = useState(uniqueId()) + useEffect(() => { + setInstanceId(`${id}-${uniqueId()}`) + }, [id]) + return ( + + + + {t(`${i18nPrefix}.roleDescription.${payload.role}`)}
+ } + > + + +
+ } + value={payload.text} + onChange={onPromptChange} + readOnly={readOnly} + showRemove={canRemove} + onRemove={onRemove} + isChatModel={isChatModel} + isChatApp={isChatApp} + isShowContext={isShowContext} + hasSetBlockStatus={hasSetBlockStatus} + nodesOutputVars={availableVars} + availableNodes={availableNodes} + /> + ) +} +export default React.memo(ConfigPromptItem) diff --git a/web/app/components/workflow/nodes/llm/components/config-prompt.tsx b/web/app/components/workflow/nodes/llm/components/config-prompt.tsx index 7ffff9cb33241..d08fa000f3ef8 100644 --- a/web/app/components/workflow/nodes/llm/components/config-prompt.tsx +++ b/web/app/components/workflow/nodes/llm/components/config-prompt.tsx @@ -6,11 +6,9 @@ import produce from 'immer' import type { PromptItem, ValueSelector, Var } from '../../../types' import { PromptRole } from '../../../types' import useAvailableVarList from '../../_base/hooks/use-available-var-list' +import ConfigPromptItem from './config-prompt-item' import Editor from '@/app/components/workflow/nodes/_base/components/prompt/editor' import AddButton from '@/app/components/workflow/nodes/_base/components/add-button' -import TypeSelector from '@/app/components/workflow/nodes/_base/components/selector' -import TooltipPlus from '@/app/components/base/tooltip-plus' -import { HelpCircle } from '@/app/components/base/icons/src/vender/line/general' const i18nPrefix = 'workflow.nodes.llm' type Props = { @@ -58,21 +56,6 @@ const ConfigPrompt: FC = ({ } }, [onChange, payload]) - const roleOptions = [ - { - label: 'system', - value: PromptRole.system, - }, - { - label: 'user', - value: PromptRole.user, - }, - { - label: 'assistant', - value: PromptRole.assistant, - }, - ] - const handleChatModeMessageRoleChange = useCallback((index: number) => { return (role: PromptRole) => { const newPrompt = produce(payload as PromptItem[], (draft) => { @@ -84,6 +67,11 @@ const ConfigPrompt: FC = ({ const handleAddPrompt = useCallback(() => { const newPrompt = produce(payload as PromptItem[], (draft) => { + if (draft.length === 0) { + draft.push({ role: PromptRole.system, text: '' }) + + return + } const isLastItemUser = draft[draft.length - 1].role === PromptRole.user draft.push({ role: isLastItemUser ? PromptRole.assistant : PromptRole.user, text: '' }) }) @@ -117,37 +105,20 @@ const ConfigPrompt: FC = ({ { (payload as PromptItem[]).map((item, index) => { return ( - - - {t(`${i18nPrefix}.roleDescription.${item.role}`)}
- } - > - - -
- } - value={item.text} - onChange={handleChatModePromptChange(index)} + 1} readOnly={readOnly} - showRemove={(payload as PromptItem[]).length > 1} - onRemove={handleRemove(index)} + id={`${payload.length}-${index}`} + handleChatModeMessageRoleChange={handleChatModeMessageRoleChange(index)} isChatModel={isChatModel} isChatApp={isChatApp} + payload={item} + onPromptChange={handleChatModePromptChange(index)} + onRemove={handleRemove(index)} isShowContext={isShowContext} hasSetBlockStatus={hasSetBlockStatus} - nodesOutputVars={availableVars} + availableVars={availableVars} availableNodes={availableNodes} /> ) diff --git a/web/app/components/workflow/nodes/llm/components/resolution-picker.tsx b/web/app/components/workflow/nodes/llm/components/resolution-picker.tsx index 7b1526ffdd33c..2d0a39ba69676 100644 --- a/web/app/components/workflow/nodes/llm/components/resolution-picker.tsx +++ b/web/app/components/workflow/nodes/llm/components/resolution-picker.tsx @@ -23,7 +23,7 @@ const Item: FC = ({ title, value, onSelect, isSelected }) => { return (
{title} @@ -43,7 +43,7 @@ const ResolutionPicker: FC = ({ const { t } = useTranslation() return ( -
+
{t(`${i18nPrefix}.resolution.name`)}
> = ({ @@ -51,6 +52,7 @@ const Panel: FC> = ({ filterVar, handlePromptChange, handleMemoryChange, + handleVisionResolutionEnabledChange, handleVisionResolutionChange, isShowSingleRun, hideSingleRun, @@ -240,12 +242,19 @@ const Panel: FC> = ({ title={t(`${i18nPrefix}.vision`)} tooltip={t('appDebug.vision.description')!} operations={ - + } - /> + > + {inputs.vision.enabled + ? ( + + ) + : null} + + )}
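Editor's note: the config-prompt refactor above also changes how new chat messages are added: the first message defaults to the system role, and later ones alternate between user and assistant based on the last entry. A small sketch of that defaulting rule with simplified local types.

```ts
// Sketch of the role-defaulting rule for a newly added chat prompt message.
type PromptRole = 'system' | 'user' | 'assistant'
type PromptItem = { role: PromptRole; text: string }

function nextPrompt(existing: PromptItem[]): PromptItem {
  if (existing.length === 0)
    return { role: 'system', text: '' } // first message defaults to system
  const lastIsUser = existing[existing.length - 1].role === 'user'
  return { role: lastIsUser ? 'assistant' : 'user', text: '' }
}
```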
diff --git a/web/app/components/workflow/nodes/llm/use-config.ts b/web/app/components/workflow/nodes/llm/use-config.ts index 5a3ebf7c54b8d..5efb49aa9d12d 100644 --- a/web/app/components/workflow/nodes/llm/use-config.ts +++ b/web/app/components/workflow/nodes/llm/use-config.ts @@ -103,6 +103,7 @@ const useConfig = (id: string, payload: LLMNodeType) => { // eslint-disable-next-line react-hooks/exhaustive-deps }, [defaultConfig, isChatModel]) + const [modelChanged, setModelChanged] = useState(false) const { currentProvider, currentModel, @@ -118,6 +119,7 @@ const useConfig = (id: string, payload: LLMNodeType) => { appendDefaultPromptConfig(draft, defaultConfig, model.mode === 'chat') }) setInputs(newInputs) + setModelChanged(true) }, [setInputs, defaultConfig, appendDefaultPromptConfig]) useEffect(() => { @@ -146,7 +148,35 @@ const useConfig = (id: string, payload: LLMNodeType) => { }, ) const isShowVisionConfig = !!currModel?.features?.includes(ModelFeatureEnum.vision) - + // change to vision model to set vision enabled, else disabled + useEffect(() => { + if (!modelChanged) + return + setModelChanged(false) + if (!isShowVisionConfig) { + const newInputs = produce(inputs, (draft) => { + draft.vision = { + enabled: false, + } + }) + setInputs(newInputs) + return + } + if (!inputs.vision?.enabled) { + const newInputs = produce(inputs, (draft) => { + if (!draft.vision?.enabled) { + draft.vision = { + enabled: true, + configs: { + detail: Resolution.high, + }, + } + } + }) + setInputs(newInputs) + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [isShowVisionConfig, modelChanged]) // variables const { handleVarListChange, handleAddVariable } = useVarList({ inputs, @@ -176,6 +206,28 @@ const useConfig = (id: string, payload: LLMNodeType) => { setInputs(newInputs) }, [inputs, setInputs]) + const handleVisionResolutionEnabledChange = useCallback((enabled: boolean) => { + const newInputs = produce(inputs, (draft) => { + if (!draft.vision) { + draft.vision = { + enabled, + configs: { + detail: Resolution.high, + }, + } + } + else { + draft.vision.enabled = enabled + if (!draft.vision.configs) { + draft.vision.configs = { + detail: Resolution.high, + } + } + } + }) + setInputs(newInputs) + }, [inputs, setInputs]) + const handleVisionResolutionChange = useCallback((newResolution: Resolution) => { const newInputs = produce(inputs, (draft) => { if (!draft.vision.configs) { @@ -296,6 +348,7 @@ const useConfig = (id: string, payload: LLMNodeType) => { filterVar, handlePromptChange, handleMemoryChange, + handleVisionResolutionEnabledChange, handleVisionResolutionChange, isShowSingleRun, hideSingleRun, diff --git a/web/app/components/workflow/nodes/start/use-config.ts b/web/app/components/workflow/nodes/start/use-config.ts index 8fe352554cb37..e30e8c283872c 100644 --- a/web/app/components/workflow/nodes/start/use-config.ts +++ b/web/app/components/workflow/nodes/start/use-config.ts @@ -28,11 +28,13 @@ const useConfig = (id: string, payload: StartNodeType) => { setFalse: hideRemoveVarConfirm, }] = useBoolean(false) const [removedVar, setRemovedVar] = useState([]) + const [removedIndex, setRemoveIndex] = useState(0) const handleVarListChange = useCallback((newList: InputVar[], moreInfo?: { index: number; payload: MoreInfo }) => { if (moreInfo?.payload?.type === ChangeType.remove) { if (isVarUsedInNodes([id, moreInfo?.payload?.payload?.beforeKey || ''])) { showRemoveVarConfirm() setRemovedVar([id, moreInfo?.payload?.payload?.beforeKey || '']) + setRemoveIndex(moreInfo?.index as number) return } 
} @@ -48,9 +50,13 @@ const useConfig = (id: string, payload: StartNodeType) => { }, [handleOutVarRenameChange, id, inputs, isVarUsedInNodes, setInputs, showRemoveVarConfirm]) const removeVarInNode = useCallback(() => { + const newInputs = produce(inputs, (draft) => { + draft.variables.splice(removedIndex, 1) + }) + setInputs(newInputs) removeUsedVarInNodes(removedVar) hideRemoveVarConfirm() - }, [hideRemoveVarConfirm, removeUsedVarInNodes, removedVar]) + }, [hideRemoveVarConfirm, inputs, removeUsedVarInNodes, removedIndex, removedVar, setInputs]) const handleAddVariable = useCallback((payload: InputVar) => { const newInputs = produce(inputs, (draft: StartNodeType) => { diff --git a/web/app/components/workflow/nodes/tool/use-config.ts b/web/app/components/workflow/nodes/tool/use-config.ts index 449a39bb9fd4d..5d499e7782796 100644 --- a/web/app/components/workflow/nodes/tool/use-config.ts +++ b/web/app/components/workflow/nodes/tool/use-config.ts @@ -1,4 +1,4 @@ -import { useCallback, useEffect, useState } from 'react' +import { useCallback, useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import produce from 'immer' import { useBoolean } from 'ahooks' @@ -25,7 +25,7 @@ const useConfig = (id: string, payload: ToolNodeType) => { const { t } = useTranslation() const language = useLanguage() - const { inputs, setInputs } = useNodeCrud(id, payload) + const { inputs, setInputs: doSetInputs } = useNodeCrud(id, payload) /* * tool_configurations: tool setting, not dynamic setting * tool_parameters: tool dynamic setting(by user) @@ -58,10 +58,41 @@ const useConfig = (id: string, payload: ToolNodeType) => { }, [currCollection?.name, hideSetAuthModal, t, handleFetchAllTools, provider_type]) const currTool = currCollection?.tools.find(tool => tool.name === tool_name) - const formSchemas = currTool ? toolParametersToFormSchemas(currTool.parameters) : [] + const formSchemas = useMemo(() => { + return currTool ? toolParametersToFormSchemas(currTool.parameters) : [] + }, [currTool]) const toolInputVarSchema = formSchemas.filter((item: any) => item.form === 'llm') // use setting const toolSettingSchema = formSchemas.filter((item: any) => item.form !== 'llm') + const hasShouldTransferTypeSettingInput = toolSettingSchema.some(item => item.type === 'boolean' || item.type === 'number-input') + + const setInputs = useCallback((value: ToolNodeType) => { + if (!hasShouldTransferTypeSettingInput) { + doSetInputs(value) + return + } + const newInputs = produce(value, (draft) => { + const newConfig = { ...draft.tool_configurations } + Object.keys(draft.tool_configurations).forEach((key) => { + const schema = formSchemas.find(item => item.variable === key) + const value = newConfig[key] + if (schema?.type === 'boolean') { + if (typeof value === 'string') + newConfig[key] = parseInt(value, 10) + + if (typeof value === 'boolean') + newConfig[key] = value ? 
1 : 0 + } + + if (schema?.type === 'number-input') { + if (typeof value === 'string' && value !== '') + newConfig[key] = parseFloat(value) + } + }) + draft.tool_configurations = newConfig + }) + doSetInputs(newInputs) + }, [doSetInputs, formSchemas, hasShouldTransferTypeSettingInput]) const [notSetDefaultValue, setNotSetDefaultValue] = useState(false) const toolSettingValue = (() => { if (notSetDefaultValue) diff --git a/web/app/components/workflow/nodes/variable-assigner/components/var-list/index.tsx b/web/app/components/workflow/nodes/variable-assigner/components/var-list/index.tsx index 668a9fd012a99..36a34003fee2b 100644 --- a/web/app/components/workflow/nodes/variable-assigner/components/var-list/index.tsx +++ b/web/app/components/workflow/nodes/variable-assigner/components/var-list/index.tsx @@ -6,6 +6,7 @@ import produce from 'immer' import RemoveButton from '../../../_base/components/remove-button' import VarReferencePicker from '@/app/components/workflow/nodes/_base/components/variable/var-reference-picker' import type { ValueSelector, Var } from '@/app/components/workflow/types' +import { VarType as VarKindType } from '@/app/components/workflow/nodes/tool/types' type Props = { readonly: boolean @@ -71,6 +72,7 @@ const VarList: FC = ({ onOpen={handleOpen(index)} onlyLeafNodeVar={onlyLeafNodeVar} filterVar={filterVar} + defaultVarKindType={VarKindType.variable} /> {!readonly && ( > = (props) => { type={(node?.data.type as BlockEnum) || BlockEnum.Start} />
-
{node?.data.title}
+
{node?.data.title}
-
{varName}
+
{varName}
{/*
{output_type}
*/}
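Editor's note: the tool use-config hunk above coerces setting values before they are persisted: boolean settings become 0/1 and numeric strings are parsed, keyed off each field's form schema type. A sketch of that coercion with simplified types; the real schemas come from `toolParametersToFormSchemas`.

```ts
// Sketch of the schema-driven coercion applied to tool_configurations before saving.
type FormSchema = { variable: string; type: 'boolean' | 'number-input' | 'text-input' }

function coerceToolConfig(config: Record<string, any>, schemas: FormSchema[]) {
  const result: Record<string, any> = { ...config }
  for (const key of Object.keys(result)) {
    const schema = schemas.find(item => item.variable === key)
    const value = result[key]
    if (schema?.type === 'boolean') {
      if (typeof value === 'string')
        result[key] = parseInt(value, 10) // '0' / '1' from the settings form
      if (typeof value === 'boolean')
        result[key] = value ? 1 : 0
    }
    if (schema?.type === 'number-input' && typeof value === 'string' && value !== '')
      result[key] = parseFloat(value)
  }
  return result
}
```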
diff --git a/web/app/components/workflow/panel/chat-record/index.tsx b/web/app/components/workflow/panel/chat-record/index.tsx index d88ac6f943695..ecc653c431621 100644 --- a/web/app/components/workflow/panel/chat-record/index.tsx +++ b/web/app/components/workflow/panel/chat-record/index.tsx @@ -35,7 +35,7 @@ const ChatRecord = () => { content: item.answer, feedback: item.feedback, isAnswer: true, - citation: item.retriever_resources, + citation: item.metadata?.retriever_resources, message_files: item.message_files?.filter((file: any) => file.belongs_to === 'assistant') || [], workflow_run_id: item.workflow_run_id, }) @@ -82,9 +82,11 @@ const ChatRecord = () => {
((_, ref) => { } as any} chatList={chatList} isResponding={isResponding} - chatContainerclassName='px-4' + chatContainerClassName='px-4' chatContainerInnerClassName='pt-6' chatFooterClassName='px-4 rounded-bl-2xl' chatFooterInnerClassName='pb-4' diff --git a/web/app/components/workflow/panel/inputs-panel.tsx b/web/app/components/workflow/panel/inputs-panel.tsx index 4dc198a3e96b2..35a7ea2393934 100644 --- a/web/app/components/workflow/panel/inputs-panel.tsx +++ b/web/app/components/workflow/panel/inputs-panel.tsx @@ -16,6 +16,7 @@ import { } from '../store' import { useWorkflowRun } from '../hooks' import type { StartNodeType } from '../nodes/start/types' +import { TransferMethod } from '../../base/text-generation/types' import Button from '@/app/components/base/button' import { useFeatures } from '@/app/components/base/features/hooks' @@ -75,6 +76,13 @@ const InputsPanel = ({ onRun }: Props) => { handleRun({ inputs, files }) } + const canRun = (() => { + if (files?.some(item => (item.transfer_method as any) === TransferMethod.local_file && !item.upload_file_id)) + return false + + return true + })() + return ( <>
@@ -97,7 +105,7 @@ const InputsPanel = ({ onRun }: Props) => {
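Editor's note: the inputs-panel hunk above (like the before-run-form change earlier) keeps the run action disabled while any locally attached file is still uploading, i.e. has no `upload_file_id` yet. A one-line sketch of that guard; `VisionFile` is reduced to the two fields that matter here.

```ts
// Sketch of the run guard: block running until local uploads have an upload_file_id.
type VisionFile = { transfer_method: 'local_file' | 'remote_url'; upload_file_id?: string }

const canRun = (files?: VisionFile[]): boolean =>
  !files?.some(file => file.transfer_method === 'local_file' && !file.upload_file_id)
```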