forked from GaiZhenbiao/ChuanhuChatGPT
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'GaiZhenbiao:main' into main
- Loading branch information
Showing
14 changed files
with
448 additions
and
84 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import logging | ||
import os | ||
|
||
import fastapi | ||
import gradio | ||
from fastapi.responses import RedirectResponse | ||
from gradio.oauth import MOCKED_OAUTH_TOKEN | ||
|
||
from modules.presets import i18n | ||
|
||
OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID") | ||
OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET") | ||
OAUTH_SCOPES = os.environ.get("OAUTH_SCOPES") | ||
OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL") | ||
def _add_oauth_routes(app: fastapi.FastAPI) -> None:
    """Add OAuth routes to the FastAPI app (login, callback handler and logout).

    Requires the ``authlib`` package and the four OAUTH_*/OPENID_* environment
    variables that Hugging Face Spaces inject when ``hf_oauth: true`` is set in
    the Space metadata.

    Raises:
        ImportError: if ``authlib`` is not installed.
        ValueError: if a required OAuth environment variable is missing.
    """
    try:
        from authlib.integrations.starlette_client import OAuth
    except ImportError as e:
        raise ImportError(
            "Cannot initialize OAuth due to a missing library. Please run `pip install gradio[oauth]` or add "
            "`gradio[oauth]` to your requirements.txt file in order to install the required dependencies."
        ) from e

    # Check environment variables
    msg = (
        "OAuth is required but {} environment variable is not set. Make sure you've enabled OAuth in your Space by"
        " setting `hf_oauth: true` in the Space metadata."
    )
    if OAUTH_CLIENT_ID is None:
        raise ValueError(msg.format("OAUTH_CLIENT_ID"))
    if OAUTH_CLIENT_SECRET is None:
        raise ValueError(msg.format("OAUTH_CLIENT_SECRET"))
    if OAUTH_SCOPES is None:
        raise ValueError(msg.format("OAUTH_SCOPES"))
    if OPENID_PROVIDER_URL is None:
        raise ValueError(msg.format("OPENID_PROVIDER_URL"))

    # Register the Hugging Face OpenID provider with authlib.
    oauth = OAuth()
    oauth.register(
        name="huggingface",
        client_id=OAUTH_CLIENT_ID,
        client_secret=OAUTH_CLIENT_SECRET,
        client_kwargs={"scope": OAUTH_SCOPES},
        server_metadata_url=OPENID_PROVIDER_URL + "/.well-known/openid-configuration",
    )

    # Define OAuth routes
    @app.get("/login/huggingface")
    async def oauth_login(request: fastapi.Request):
        """Endpoint that redirects to HF OAuth page."""
        redirect_uri = str(request.url_for("oauth_redirect_callback"))
        if ".hf.space" in redirect_uri:
            # In Space, FastAPI redirect as http but we want https
            redirect_uri = redirect_uri.replace("http://", "https://")
        return await oauth.huggingface.authorize_redirect(request, redirect_uri)

    @app.get("/login/callback")
    async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that handles the OAuth callback."""
        token = await oauth.huggingface.authorize_access_token(request)
        # Persist both the user profile and the raw token in the session cookie.
        request.session["oauth_profile"] = token["userinfo"]
        request.session["oauth_token"] = token
        return RedirectResponse("/")

    @app.get("/logout")
    async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that logs out the user (e.g. delete cookie session)."""
        request.session.pop("oauth_profile", None)
        request.session.pop("oauth_token", None)
        # Clear the auth cookies and redirect to the home page.
        response = RedirectResponse(url="/", status_code=302)
        response.delete_cookie(key="access-token")
        response.delete_cookie(key="access-token-unsecure")
        return response
|
||
|
||
def _add_mocked_oauth_routes(app: fastapi.FastAPI) -> None:
    """Add fake oauth routes if Gradio is run locally and OAuth is enabled.

    Clicking on a gr.LoginButton will have the same behavior as in a Space
    (i.e. gets redirected in a new tab) but instead of authenticating with HF,
    a mocked user profile is added to the session.
    """

    # Define OAuth routes
    @app.get("/login/huggingface")
    async def oauth_login(request: fastapi.Request):
        """Fake endpoint that redirects to HF OAuth page."""
        return RedirectResponse("/login/callback")

    @app.get("/login/callback")
    async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that handles the OAuth callback."""
        # Store the mocked profile/token so downstream code behaves as in a Space.
        request.session["oauth_profile"] = MOCKED_OAUTH_TOKEN["userinfo"]
        request.session["oauth_token"] = MOCKED_OAUTH_TOKEN
        return RedirectResponse("/")

    @app.get("/logout")
    async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that logs out the user (e.g. delete cookie session)."""
        request.session.pop("oauth_profile", None)
        request.session.pop("oauth_token", None)
        # Clear the auth cookies and redirect to the home page.
        response = RedirectResponse(url="/", status_code=302)
        response.delete_cookie(key="access-token")
        response.delete_cookie(key="access-token-unsecure")
        return response
|
||
|
||
def reg_patch():
    """Monkey-patch gradio.oauth so our routes (with a working /logout) replace the stock ones."""
    gradio.oauth._add_oauth_routes = _add_oauth_routes
    gradio.oauth._add_mocked_oauth_routes = _add_mocked_oauth_routes
    logging.info(i18n("覆盖gradio.oauth /logout路由"))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
import json | ||
import logging | ||
import textwrap | ||
import uuid | ||
|
||
import google.generativeai as genai | ||
import gradio as gr | ||
import PIL | ||
import requests | ||
|
||
from modules.presets import i18n | ||
|
||
from ..index_func import construct_index | ||
from ..utils import count_token | ||
from .base_model import BaseLLMModel | ||
|
||
|
||
class GoogleGeminiClient(BaseLLMModel):
    """Chat client backed by Google's Gemini API (google.generativeai).

    Models whose name contains "vision" are treated as multimodal and accept
    image uploads; for other models, uploaded files are routed into the local
    document index instead.
    """

    def __init__(self, model_name, api_key, user_name="") -> None:
        super().__init__(model_name=model_name, user=user_name)
        self.api_key = api_key
        # "vision" model variants accept images alongside text.
        self.multimodal = "vision" in model_name.lower()
        # Image paths queued by handle_file_upload, consumed on the next request.
        self.image_paths = []

    def _get_gemini_style_input(self):
        """Flatten history (plus any queued images) into the list Gemini expects.

        NOTE: permanently appends the queued images to ``self.history`` before
        converting each entry to either a PIL image or a plain string.
        """
        self.history.extend([{"role": "image", "content": i} for i in self.image_paths])
        self.image_paths = []
        messages = []
        for item in self.history:
            if item["role"] == "image":
                messages.append(PIL.Image.open(item["content"]))
            else:
                messages.append(item["content"])
        return messages

    def to_markdown(self, text):
        """Render Gemini output as a Markdown blockquote (bullets become ``*``)."""
        text = text.replace("•", " *")
        return textwrap.indent(text, "> ", predicate=lambda _: True)

    def handle_file_upload(self, files, chatbot, language):
        """Queue images (multimodal) or build a search index from uploads.

        Returns a ``(files_update, chatbot, status)`` triple for the Gradio UI.
        """
        if not files:
            # Fix: the original fell through and implicitly returned None,
            # which broke callers that unpack three values.
            return None, chatbot, None
        if self.multimodal:
            for file in files:
                if file.name:
                    self.image_paths.append(file.name)
                    chatbot = chatbot + [((file.name,), None)]
            return None, chatbot, None
        construct_index(self.api_key, file_src=files)
        status = i18n("索引构建完成")
        return gr.Files.update(), chatbot, status

    def get_answer_at_once(self):
        """Send the whole conversation at once; return ``(reply, token_count)``."""
        genai.configure(api_key=self.api_key)
        messages = self._get_gemini_style_input()
        model = genai.GenerativeModel(self.model_name)
        response = model.generate_content(messages)
        try:
            return self.to_markdown(response.text), len(response.text)
        except ValueError:
            # response.text raises ValueError when Gemini blocked the answer
            # (e.g. safety filters); surface the refusal reason instead.
            return (
                i18n("由于下面的原因,Google 拒绝返回 Gemini 的回答:\n\n")
                + str(response.prompt_feedback),
                0,
            )

    def get_answer_stream_iter(self):
        """Yield the growing reply as chunks arrive from the Gemini stream."""
        genai.configure(api_key=self.api_key)
        messages = self._get_gemini_style_input()
        model = genai.GenerativeModel(self.model_name)
        stream = model.generate_content(messages, stream=True)
        partial_text = ""
        # Fix: the original rebound `response` to each chunk's text inside the
        # loop, shadowing the stream object; use distinct names instead.
        for chunk in stream:
            partial_text += chunk.text
            yield partial_text
        self.all_token_counts[-1] = count_token(partial_text)
        yield partial_text
Oops, something went wrong.