From adae368425e086f0f0980280e5c499e8dd8ac2f9 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Sun, 10 Mar 2024 17:52:41 +0200 Subject: [PATCH] add reddit --- STATUS.md | 2 +- config.example.py | 6 ++- yepcord/rest_api/routes/connections.py | 54 +++++++++++++++++++++++++- yepcord/yepcord/config.py | 5 ++- 4 files changed, 61 insertions(+), 6 deletions(-) diff --git a/STATUS.md b/STATUS.md index 8ccb3e0..1d1f7a5 100644 --- a/STATUS.md +++ b/STATUS.md @@ -46,7 +46,7 @@ - [x] Notes - [ ] Connections: - [ ] PayPal - - [ ] Reddit + - [x] Reddit - [ ] Steam - [ ] TikTok - [ ] Twitter diff --git a/config.example.py b/config.example.py index 41cad45..2865cdc 100644 --- a/config.example.py +++ b/config.example.py @@ -94,5 +94,9 @@ "github": { "client_id": None, "client_secret": None, - } + }, + "reddit": { + "client_id": None, + "client_secret": None, + }, } diff --git a/yepcord/rest_api/routes/connections.py b/yepcord/rest_api/routes/connections.py index a4e6574..2e0267a 100644 --- a/yepcord/rest_api/routes/connections.py +++ b/yepcord/rest_api/routes/connections.py @@ -15,7 +15,7 @@ You should have received a copy of the GNU Affero General Public License along with this program. If not, see . """ - +from base64 import b64encode from typing import Optional from urllib.parse import quote @@ -58,7 +58,7 @@ def parse_state(state: str) -> tuple[Optional[int], Optional[int]]: @connections.get("/github/authorize") async def connection_github_authorize(user: User = DepUser): client_id = get_service_settings("github", "client_id")["client_id"] - callback_url = quote(f"https://{Config.PUBLIC_HOST}/connections/github/callback") + callback_url = quote(f"https://{Config.PUBLIC_HOST}/connections/github/callback", safe="") conn, _ = await ConnectedAccount.get_or_create(user=user, type="github", verified=False) @@ -101,3 +101,53 @@ async def connection_github_callback(data: ConnectionCallback): await getGw().dispatch(UserConnectionsUpdate(conn), user_ids=[user_id]) return "", 204 + + +@connections.get("/reddit/authorize") +async def connection_reddit_authorize(user: User = DepUser): + client_id = get_service_settings("reddit", "client_id")["client_id"] + callback_url = quote(f"https://{Config.PUBLIC_HOST}/connections/reddit/callback", safe="") + + conn, _ = await ConnectedAccount.get_or_create(user=user, type="reddit", verified=False) + + url = (f"https://www.reddit.com/api/v1/authorize?client_id={client_id}&redirect_uri={callback_url}" + f"&scope=identity&state={user.id}.{conn.state}&response_type=code") + + return {"url": url} + + +@connections.post("/reddit/callback", body_cls=ConnectionCallback) +async def connection_reddit_callback(data: ConnectionCallback): + callback_url = quote(f"https://{Config.PUBLIC_HOST}/connections/reddit/callback", safe="") + settings = get_service_settings("reddit", "client_id") + client_id = settings["client_id"] + client_secret = settings["client_secret"] + user_id, state = parse_state(data.state) + if user_id is None: + return "", 204 + if (conn := await ConnectedAccount.get_or_none(user__id=user_id, state=state, verified=False, type="reddit")) \ + is None: + return "", 204 + + async with AsyncClient() as cl: + resp = await cl.post(f"https://www.reddit.com/api/v1/access_token", auth=(client_id, client_secret), + headers={"Accept": "application/json", "Content-Type": "application/x-www-form-urlencoded"}, + content=f"grant_type=authorization_code&code={data.code}&redirect_uri={callback_url}") + if resp.status_code >= 400 or "error" in (j := resp.json()): + raise InvalidDataErr(400, Errors.make(0)) + + access_token = j["access_token"] + + resp = await cl.get("https://oauth.reddit.com/api/v1/me", headers={"Authorization": f"Bearer {access_token}"}) + if resp.status_code >= 400: + raise InvalidDataErr(400, Errors.make(0)) + j = resp.json() + + if await ConnectedAccount.filter(type="reddit", service_id=j["id"]).exists(): + return "", 204 + + await conn.update(service_id=j["id"], name=j["name"], access_token=access_token, verified=True) + + await getGw().dispatch(UserConnectionsUpdate(conn), user_ids=[user_id]) + + return "", 204 diff --git a/yepcord/yepcord/config.py b/yepcord/yepcord/config.py index b2f584b..1ea42c9 100644 --- a/yepcord/yepcord/config.py +++ b/yepcord/yepcord/config.py @@ -106,13 +106,14 @@ class ConfigCaptcha(BaseModel): recaptcha: ConfigCaptchaService = Field(default_factory=ConfigCaptchaService) -class ConfigConnectionGithub(BaseModel): +class ConfigConnectionBase(BaseModel): client_id: Optional[str] = None client_secret: Optional[str] = None class ConfigConnections(BaseModel): - github: ConfigConnectionGithub = Field(default_factory=ConfigConnectionGithub) + github: ConfigConnectionBase = Field(default_factory=ConfigConnectionBase) + reddit: ConfigConnectionBase = Field(default_factory=ConfigConnectionBase) class ConfigModel(BaseModel):