Skip to content

Commit

Permalink
add tests for connections
Browse files Browse the repository at this point in the history
  • Loading branch information
RuslanUC committed Mar 11, 2024
1 parent 5a16898 commit 5d393d6
Show file tree
Hide file tree
Showing 7 changed files with 417 additions and 105 deletions.
218 changes: 118 additions & 100 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ uvloop = "0.19.0"
async-timeout = "^4.0.3"
aerich = "^0.7.2"
yc-protobuf3-to-dict = "^0.3.0"
s3lite = "^0.1.3"
s3lite = "^0.1.4"
fast-depends = ">=2.1.1,<2.2.0"

[tool.poetry.group.dev.dependencies]
Expand All @@ -75,6 +75,7 @@ pytest-asyncio = "0.23.5"
pyftpdlib = "1.5.8"
fake-s3 = "1.0.2"
types-protobuf = "^4.24.0.4"
pytest-httpx = "^0.30.0"

[build-system]
requires = ["poetry-core"]
Expand Down
202 changes: 202 additions & 0 deletions tests/api/test_connections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
import re
from os import urandom
from urllib import parse

import pytest as pt
import pytest_asyncio
from pytest_httpx import HTTPXMock

from tests.api.utils import TestClientType, create_users
from tests.httpx_mock_callbacks import github_oauth_token_exchange, github_oauth_user_get, reddit_oauth_token_exchange, \
reddit_oauth_user_get, twitch_oauth_token_exchange, spotify_oauth_token_exchange, twitch_oauth_user_get, \
spotify_oauth_user_get
from yepcord.rest_api.main import app
from yepcord.yepcord.config import Config


@pytest_asyncio.fixture(autouse=True)
async def setup_db():
for func in app.before_serving_funcs:
await app.ensure_async(func)()
yield
for func in app.after_serving_funcs:
await app.ensure_async(func)()


httpx_token_callbacks = {
"github": (github_oauth_token_exchange, {"url": re.compile(r'https://github.com/login/oauth/access_token?.+')}),
"reddit": (reddit_oauth_token_exchange, {"url": "https://www.reddit.com/api/v1/access_token"}),
"twitch": (twitch_oauth_token_exchange, {"url": "https://id.twitch.tv/oauth2/token"}),
"spotify": (spotify_oauth_token_exchange, {"url": "https://accounts.spotify.com/api/token"}),
}
httpx_user_callbacks = {
"github": (github_oauth_user_get, {"url": "https://api.github.com/user"}),
"reddit": (reddit_oauth_user_get, {"url": "https://oauth.reddit.com/api/v1/me"}),
"twitch": (twitch_oauth_user_get, {"url": "https://api.twitch.tv/helix/users"}),
"spotify": (spotify_oauth_user_get, {"url": "https://api.spotify.com/v1/me"}),
}


@pt.mark.parametrize("service_name", ["github", "reddit", "twitch", "spotify"])
@pt.mark.asyncio
async def test_connection(service_name: str, httpx_mock: HTTPXMock):
Config.update({"CONNECTIONS": {service_name: {"client_id": urandom(8).hex(), "client_secret": urandom(8).hex()}}})
code = urandom(8).hex()
access_token = urandom(8).hex()

client: TestClientType = app.test_client()
user = (await create_users(client, 1))[0]
headers = {"Authorization": user["token"]}

cb, kw = httpx_token_callbacks[service_name]
httpx_mock.add_callback(cb(**Config.CONNECTIONS[service_name], code=code, access_token=access_token), **kw)
cb, kw = httpx_user_callbacks[service_name]
httpx_mock.add_callback(cb(access_token=access_token), **kw)

resp = await client.get(f"/api/v9/connections/{service_name}/authorize", headers=headers)
assert resp.status_code == 200
j = await resp.get_json()
state = dict(parse.parse_qsl(parse.urlsplit(j["url"]).query))["state"]

resp = await client.post(f"/api/v9/connections/{service_name}/callback", headers=headers,
json={"code": code, "state": state, "insecure": False, "friend_sync": False})
assert resp.status_code == 204, await resp.get_json()

resp = await client.get("/api/v9/users/@me/connections", headers=headers)
assert resp.status_code == 200
j = await resp.get_json()
assert len(j) == 1


@pt.mark.asyncio
async def test_connection_wrong_state():
client: TestClientType = app.test_client()
user = (await create_users(client, 1))[0]
headers = {"Authorization": user["token"]}

state = "123.456"
resp = await client.post(f"/api/v9/connections/github/callback", headers=headers,
json={"code": "123456", "state": state, "insecure": False, "friend_sync": False})
assert resp.status_code == 204, await resp.get_json()

state = "abc-456"
resp = await client.post(f"/api/v9/connections/github/callback", headers=headers,
json={"code": "123456", "state": state, "insecure": False, "friend_sync": False})
assert resp.status_code == 204, await resp.get_json()

resp = await client.get("/api/v9/users/@me/connections", headers=headers)
assert resp.status_code == 200
j = await resp.get_json()
assert len(j) == 0


@pt.mark.asyncio
async def test_connection_wrong_code(httpx_mock: HTTPXMock):
Config.update({"CONNECTIONS": {"github": {"client_id": urandom(8).hex(), "client_secret": urandom(8).hex()}}})
code = urandom(8).hex()
access_token = urandom(8).hex()

client: TestClientType = app.test_client()
user = (await create_users(client, 1))[0]
headers = {"Authorization": user["token"]}

cb, kw = httpx_token_callbacks["github"]
httpx_mock.add_callback(cb(**Config.CONNECTIONS["github"], code=code, access_token=access_token), **kw)

resp = await client.get(f"/api/v9/connections/github/authorize", headers=headers)
assert resp.status_code == 200
j = await resp.get_json()
state = dict(parse.parse_qsl(parse.urlsplit(j["url"]).query))["state"]

resp = await client.post(f"/api/v9/connections/github/callback", headers=headers,
json={"code": code+"1", "state": state, "insecure": False, "friend_sync": False})
assert resp.status_code == 400

resp = await client.get("/api/v9/users/@me/connections", headers=headers)
assert resp.status_code == 200
j = await resp.get_json()
assert len(j) == 0


@pt.mark.asyncio
async def test_connection_add_same_account_twice(httpx_mock: HTTPXMock):
Config.update({"CONNECTIONS": {"github": {"client_id": urandom(8).hex(), "client_secret": urandom(8).hex()}}})
code = urandom(8).hex()
access_token = urandom(8).hex()

client: TestClientType = app.test_client()
user = (await create_users(client, 1))[0]
headers = {"Authorization": user["token"]}

cb, kw = httpx_token_callbacks["github"]
httpx_mock.add_callback(cb(**Config.CONNECTIONS["github"], code=code, access_token=access_token), **kw)
cb, kw = httpx_user_callbacks["github"]
httpx_mock.add_callback(cb(access_token=access_token), **kw)

for _ in range(2):
resp = await client.get(f"/api/v9/connections/github/authorize", headers=headers)
assert resp.status_code == 200
j = await resp.get_json()
state = dict(parse.parse_qsl(parse.urlsplit(j["url"]).query))["state"]

resp = await client.post(f"/api/v9/connections/github/callback", headers=headers,
json={"code": code, "state": state, "insecure": False, "friend_sync": False})
assert resp.status_code == 204, await resp.get_json()

resp = await client.get("/api/v9/users/@me/connections", headers=headers)
assert resp.status_code == 200
j = await resp.get_json()
assert len(j) == 1


@pt.mark.asyncio
async def test_connection_edit_delete(httpx_mock: HTTPXMock):
Config.update({"CONNECTIONS": {"github": {"client_id": urandom(8).hex(), "client_secret": urandom(8).hex()}}})
code = urandom(8).hex()
access_token = urandom(8).hex()

client: TestClientType = app.test_client()
user = (await create_users(client, 1))[0]
headers = {"Authorization": user["token"]}

cb, kw = httpx_token_callbacks["github"]
httpx_mock.add_callback(cb(**Config.CONNECTIONS["github"], code=code, access_token=access_token), **kw)
cb, kw = httpx_user_callbacks["github"]
httpx_mock.add_callback(cb(access_token=access_token), **kw)

resp = await client.get(f"/api/v9/connections/github/authorize", headers=headers)
assert resp.status_code == 200
j = await resp.get_json()
state = dict(parse.parse_qsl(parse.urlsplit(j["url"]).query))["state"]

resp = await client.post(f"/api/v9/connections/github/callback", headers=headers,
json={"code": code, "state": state, "insecure": False, "friend_sync": False})
assert resp.status_code == 204, await resp.get_json()

resp = await client.get("/api/v9/users/@me/connections", headers=headers)
assert resp.status_code == 200
j = await resp.get_json()
assert len(j) == 1

conn_id = j[0]["id"]

resp = await client.patch(f"/api/v9/users/@me/connections/github1/{conn_id}", headers=headers,
json={"visibility": False})
assert resp.status_code == 400

resp = await client.patch(f"/api/v9/users/@me/connections/github/{conn_id}1", headers=headers,
json={"visibility": False})
assert resp.status_code == 404

resp = await client.patch(f"/api/v9/users/@me/connections/github/{conn_id}", headers=headers,
json={"visibility": False})
assert resp.status_code == 200

resp = await client.delete(f"/api/v9/users/@me/connections/github1/{conn_id}", headers=headers)
assert resp.status_code == 400

resp = await client.delete(f"/api/v9/users/@me/connections/github/{conn_id}1", headers=headers)
assert resp.status_code == 404

resp = await client.delete(f"/api/v9/users/@me/connections/github/{conn_id}", headers=headers)
assert resp.status_code == 204
92 changes: 92 additions & 0 deletions tests/httpx_mock_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from httpx import Request, Response

from yepcord.yepcord.utils import b64decode


def github_oauth_token_exchange(client_id: str, client_secret: str, code: str, access_token: str):
def _github_oauth_token_exchange(request: Request) -> Response:
params = request.url.params
if params["client_id"] != client_id or params["client_secret"] != client_secret or params["code"] != code:
return Response(status_code=400, json={"error": ""})

return Response(status_code=200, json={"access_token": access_token})

return _github_oauth_token_exchange


def github_oauth_user_get(access_token: str):
def _github_oauth_user_get(request: Request) -> Response:
if request.headers["Authorization"] != f"Bearer {access_token}":
return Response(status_code=401, json={"error": ""})

return Response(status_code=200, json={"id": str(int(f"0x{access_token[:6]}", 16)), "login": access_token[:8]})

return _github_oauth_user_get


def reddit_oauth_token_exchange(client_id: str, client_secret: str, code: str, access_token: str):
def _reddit_oauth_token_exchange(request: Request) -> Response:
params = {k: v for k, v in [param.split("=") for param in request.content.decode("utf8").split("&")]}
client_id_, client_secret_ = b64decode(request.headers["Authorization"][6:]).decode("utf8").split(":")
if params["code"] != code or client_id_ != client_id or client_secret_ != client_secret:
return Response(status_code=400, json={"error": ""})

return Response(status_code=200, json={"access_token": access_token})

return _reddit_oauth_token_exchange


def reddit_oauth_user_get(access_token: str):
def _reddit_oauth_user_get(request: Request) -> Response:
if request.headers["Authorization"] != f"Bearer {access_token}":
return Response(status_code=401, json={"error": ""})

return Response(status_code=200, json={"id": str(int(f"0x{access_token[:6]}", 16)), "name": access_token[:8]})

return _reddit_oauth_user_get


def twitch_oauth_token_exchange(client_id: str, client_secret: str, code: str, access_token: str):
def _twitch_oauth_token_exchange(request: Request) -> Response:
params = {k: v for k, v in [param.split("=") for param in request.content.decode("utf8").split("&")]}
if params["code"] != code or params["client_id"] != client_id or params["client_secret"] != client_secret:
return Response(status_code=400, json={"error": ""})

return Response(status_code=200, json={"access_token": access_token})

return _twitch_oauth_token_exchange


def twitch_oauth_user_get(access_token: str):
def _twitch_oauth_user_get(request: Request) -> Response:
if request.headers["Authorization"] != f"Bearer {access_token}":
return Response(status_code=401, json={"error": ""})

return Response(status_code=200, json={"data": [
{"id": str(int(f"0x{access_token[:6]}", 16)), "login": access_token[:8]}
]})

return _twitch_oauth_user_get


def spotify_oauth_token_exchange(client_id: str, client_secret: str, code: str, access_token: str):
def _spotify_oauth_token_exchange(request: Request) -> Response:
params = {k: v for k, v in [param.split("=") for param in request.content.decode("utf8").split("&")]}
client_id_, client_secret_ = b64decode(request.headers["Authorization"][6:]).decode("utf8").split(":")
if params["code"] != code or client_id_ != client_id or client_secret_ != client_secret:
return Response(status_code=400, json={"error": ""})

return Response(status_code=200, json={"access_token": access_token})

return _spotify_oauth_token_exchange


def spotify_oauth_user_get(access_token: str):
def _spotify_oauth_user_get(request: Request) -> Response:
if request.headers["Authorization"] != f"Bearer {access_token}":
return Response(status_code=401, json={"error": ""})

return Response(status_code=200, json={"id": str(int(f"0x{access_token[:6]}", 16)),
"display_name": access_token[:8]})

return _spotify_oauth_user_get
3 changes: 1 addition & 2 deletions yepcord/rest_api/routes/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ async def unified_callback(connection_cls: type[BaseConnection], data: Connectio
if (conn := await connection_cls.get_connection_from_state(data.state)) is None:
return "", 204

if (access_token := await connection_cls.exchange_code(data.code)) is None:
return "", 204
access_token = await connection_cls.exchange_code(data.code)

user_info = await connection_cls.get_user_info(access_token)
if await ConnectedAccount.filter(type=connection_cls.SERVICE_NAME, service_id=user_info["id"]).exists():
Expand Down
2 changes: 1 addition & 1 deletion yepcord/rest_api/routes/users_me.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ async def update_protobuf_frecency_settings(data: SettingsProtoUpdate, user: Use

@users_me.get("/connections", oauth_scopes=["connections"])
async def get_connections(user: User = DepUser):
connections = await ConnectedAccount.filter(user=user)
connections = await ConnectedAccount.filter(user=user, verified=True)
return [conn.ds_json() for conn in connections]


Expand Down
2 changes: 1 addition & 1 deletion yepcord/yepcord/classes/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ async def get_user_info(cls, access_token: str) -> dict:
url, kwargs = cls.user_info_req(access_token)
async with AsyncClient() as cl:
resp = await cl.get(url, **kwargs)
if resp.status_code >= 400:
if resp.status_code >= 400: # pragma: no cover
raise InvalidDataErr(400, Errors.make(0))
return resp.json()

Expand Down

0 comments on commit 5d393d6

Please sign in to comment.