Skip to content

Commit

Permalink
add twitch
Browse files Browse the repository at this point in the history
  • Loading branch information
RuslanUC committed Mar 10, 2024
1 parent adae368 commit d09bb09
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 102 deletions.
2 changes: 1 addition & 1 deletion STATUS.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
- [x] Github
- [ ] League of Legends
- [ ] Riot Games
- [ ] Twitch
- [x] Twitch
- [ ] YouTube
- [x] OAuth2
- [ ] Bots:
Expand Down
4 changes: 4 additions & 0 deletions config.example.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,8 @@
"client_id": None,
"client_secret": None,
},
"twitch": {
"client_id": None,
"client_secret": None,
},
}
1 change: 1 addition & 0 deletions yepcord/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def run_all(config: str, host: str, port: int, reload: bool, ssl: bool) -> None:
"forwarded_allow_ips": "'*'",
"host": host,
"port": port,
"timeout_graceful_shutdown": 1,
}

if reload:
Expand Down
125 changes: 24 additions & 101 deletions yepcord/rest_api/routes/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,139 +15,62 @@
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from base64 import b64encode
from typing import Optional
from urllib.parse import quote

from httpx import AsyncClient

from ..dependencies import DepUser
from ..models.connections import ConnectionCallback
from ..y_blueprint import YBlueprint
from ...gateway.events import UserConnectionsUpdate
from ...yepcord.config import Config
from ...yepcord.classes.connections import ConnectionGithub, ConnectionReddit, ConnectionTwitch, BaseConnection
from ...yepcord.ctx import getGw
from ...yepcord.errors import InvalidDataErr, Errors
from ...yepcord.models import User, ConnectedAccount

# Base path is /api/vX/connections
connections = YBlueprint("connections", __name__)


def get_service_settings(service_name: str, check_field: Optional[str] = None) -> dict:
settings = Config.CONNECTIONS[service_name]
if check_field is not None and settings[check_field] is None:
raise InvalidDataErr(400, Errors.make(50035, {"provider_id": {
"code": "BASE_TYPE_INVALID", "message": "This connection has been disabled server-side."
}}))

return settings
async def unified_callback(connection_cls: type[BaseConnection], data: ConnectionCallback,
user_login_field: str = "login"):
if (conn := await connection_cls.get_connection_from_state(data.state)) is None:
return "", 204

if (access_token := await connection_cls.exchange_code(data.code)) is None:
return "", 204

def parse_state(state: str) -> tuple[Optional[int], Optional[int]]:
state = state.split(".")
if len(state) != 2:
return None, None
user_id, real_state = state
if not user_id.isdigit() or not real_state.isdigit():
return None, None
user_info = await connection_cls.get_user_info(access_token)
if await ConnectedAccount.filter(type=connection_cls.SERVICE_NAME, service_id=user_info["id"]).exists():
return "", 204

return int(user_id), int(real_state)
await conn.update(service_id=user_info["id"], name=user_info[user_login_field], access_token=access_token,
verified=True)
await getGw().dispatch(UserConnectionsUpdate(conn), user_ids=[int(data.state.split(".")[0])])
return "", 204


@connections.get("/github/authorize")
async def connection_github_authorize(user: User = DepUser):
client_id = get_service_settings("github", "client_id")["client_id"]
callback_url = quote(f"https://{Config.PUBLIC_HOST}/connections/github/callback", safe="")

conn, _ = await ConnectedAccount.get_or_create(user=user, type="github", verified=False)

url = (f"https://github.com/login/oauth/authorize?client_id={client_id}&redirect_uri={callback_url}"
f"&scope=read%3Auser&state={user.id}.{conn.state}")

return {"url": url}
return {"url": await ConnectionGithub.authorize_url(user)}


@connections.post("/github/callback", body_cls=ConnectionCallback)
async def connection_github_callback(data: ConnectionCallback):
settings = get_service_settings("github", "client_id")
client_id = settings["client_id"]
client_secret = settings["client_secret"]
user_id, state = parse_state(data.state)
if user_id is None:
return "", 204
if (conn := await ConnectedAccount.get_or_none(user__id=user_id, state=state, verified=False, type="github")) \
is None:
return "", 204

async with AsyncClient() as cl:
resp = await cl.post(f"https://github.com/login/oauth/access_token?client_id={client_id}"
f"&client_secret={client_secret}&code={data.code}", headers={"Accept": "application/json"})
if resp.status_code >= 400 or "error" in (j := resp.json()):
raise InvalidDataErr(400, Errors.make(0))

access_token = j["access_token"]

resp = await cl.get("https://api.github.com/user", headers={"Authorization": f"Bearer {access_token}"})
if resp.status_code >= 400:
raise InvalidDataErr(400, Errors.make(0))
j = resp.json()

if await ConnectedAccount.filter(type="github", service_id=j["id"]).exists():
return "", 204

await conn.update(service_id=j["id"], name=j["login"], access_token=access_token, verified=True)

await getGw().dispatch(UserConnectionsUpdate(conn), user_ids=[user_id])

return "", 204
return await unified_callback(ConnectionGithub, data)


@connections.get("/reddit/authorize")
async def connection_reddit_authorize(user: User = DepUser):
client_id = get_service_settings("reddit", "client_id")["client_id"]
callback_url = quote(f"https://{Config.PUBLIC_HOST}/connections/reddit/callback", safe="")

conn, _ = await ConnectedAccount.get_or_create(user=user, type="reddit", verified=False)

url = (f"https://www.reddit.com/api/v1/authorize?client_id={client_id}&redirect_uri={callback_url}"
f"&scope=identity&state={user.id}.{conn.state}&response_type=code")

return {"url": url}
return {"url": await ConnectionReddit.authorize_url(user)}


@connections.post("/reddit/callback", body_cls=ConnectionCallback)
async def connection_reddit_callback(data: ConnectionCallback):
callback_url = quote(f"https://{Config.PUBLIC_HOST}/connections/reddit/callback", safe="")
settings = get_service_settings("reddit", "client_id")
client_id = settings["client_id"]
client_secret = settings["client_secret"]
user_id, state = parse_state(data.state)
if user_id is None:
return "", 204
if (conn := await ConnectedAccount.get_or_none(user__id=user_id, state=state, verified=False, type="reddit")) \
is None:
return "", 204
return await unified_callback(ConnectionReddit, data, "name")

async with AsyncClient() as cl:
resp = await cl.post(f"https://www.reddit.com/api/v1/access_token", auth=(client_id, client_secret),
headers={"Accept": "application/json", "Content-Type": "application/x-www-form-urlencoded"},
content=f"grant_type=authorization_code&code={data.code}&redirect_uri={callback_url}")
if resp.status_code >= 400 or "error" in (j := resp.json()):
raise InvalidDataErr(400, Errors.make(0))

access_token = j["access_token"]
@connections.get("/twitch/authorize")
async def connection_twitch_authorize(user: User = DepUser):
return {"url": await ConnectionTwitch.authorize_url(user)}

resp = await cl.get("https://oauth.reddit.com/api/v1/me", headers={"Authorization": f"Bearer {access_token}"})
if resp.status_code >= 400:
raise InvalidDataErr(400, Errors.make(0))
j = resp.json()

if await ConnectedAccount.filter(type="reddit", service_id=j["id"]).exists():
return "", 204

await conn.update(service_id=j["id"], name=j["name"], access_token=access_token, verified=True)

await getGw().dispatch(UserConnectionsUpdate(conn), user_ids=[user_id])

return "", 204
@connections.post("/twitch/callback", body_cls=ConnectionCallback)
async def connection_twitch_callback(data: ConnectionCallback):
return await unified_callback(ConnectionTwitch, data)
155 changes: 155 additions & 0 deletions yepcord/yepcord/classes/connections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from abc import ABC, abstractmethod
from typing import Optional
from urllib.parse import quote

from httpx import AsyncClient

from yepcord.yepcord.config import Config
from yepcord.yepcord.errors import InvalidDataErr, Errors
from yepcord.yepcord.models import User, ConnectedAccount


def get_service_settings(service_name: str, check_field: Optional[str] = None) -> dict:
settings = Config.CONNECTIONS[service_name]
if check_field is not None and settings[check_field] is None:
raise InvalidDataErr(400, Errors.make(50035, {"provider_id": {
"code": "BASE_TYPE_INVALID", "message": "This connection has been disabled server-side."
}}))

return settings


def parse_state(state: str) -> tuple[Optional[int], Optional[int]]:
state = state.split(".")
if len(state) != 2:
return None, None
user_id, real_state = state
if not user_id.isdigit() or not real_state.isdigit():
return None, None

return int(user_id), int(real_state)


class BaseConnection(ABC):
SERVICE_NAME = ""
AUTHORIZE_URL = ""
TOKEN_URL = ""
USER_URL = ""
SCOPE: list[str] = []

@classmethod
async def authorize_url(cls, user: User) -> str:
client_id = get_service_settings(cls.SERVICE_NAME, "client_id")["client_id"]
callback_url = quote(f"https://{Config.PUBLIC_HOST}/connections/{cls.SERVICE_NAME}/callback", safe="")

conn, _ = await ConnectedAccount.get_or_create(user=user, type=cls.SERVICE_NAME, verified=False)

scope = quote(" ".join(cls.SCOPE))
return (f"{cls.AUTHORIZE_URL}?client_id={client_id}&redirect_uri={callback_url}&scope={scope}"
f"&state={user.id}.{conn.state}")

@classmethod
@abstractmethod
def exchange_code_req(cls, code: str, settings: dict[str, str]) -> tuple[str, dict]: ...

@classmethod
async def get_connection_from_state(cls, state: str) -> Optional[ConnectedAccount]:
user_id, state = parse_state(state)
if user_id is None:
return
return await ConnectedAccount.get_or_none(user__id=user_id, state=state, verified=False, type=cls.SERVICE_NAME)

@classmethod
async def exchange_code(cls, code: str) -> Optional[str]:
settings = get_service_settings(cls.SERVICE_NAME, "client_id")

async with AsyncClient() as cl:
url, kwargs = cls.exchange_code_req(code, settings)
resp = await cl.post(url, **kwargs)
if resp.status_code >= 400 or "error" in (j := resp.json()):
raise InvalidDataErr(400, Errors.make(0))

return j["access_token"]

@classmethod
def user_info_req(cls, access_token: str) -> tuple[str, dict]:
return cls.USER_URL, {"headers": {"Authorization": f"Bearer {access_token}"}}

@classmethod
async def get_user_info(cls, access_token: str) -> dict:
url, kwargs = cls.user_info_req(access_token)
async with AsyncClient() as cl:
resp = await cl.get(url, **kwargs)
if resp.status_code >= 400:
raise InvalidDataErr(400, Errors.make(0))
return resp.json()


class ConnectionGithub(BaseConnection):
SERVICE_NAME = "github"
AUTHORIZE_URL = "https://github.com/login/oauth/authorize"
TOKEN_URL = "https://github.com/login/oauth/access_token"
USER_URL = "https://api.github.com/user"
SCOPE: list[str] = ["read:user"]

@classmethod
def exchange_code_req(cls, code: str, settings: dict[str, str]) -> tuple[str, dict]:
url = f"{cls.TOKEN_URL}?client_id={settings['client_id']}&client_secret={settings['client_secret']}&code={code}"
kwargs = {"headers": {"Accept": "application/json"}}

return url, kwargs


class ConnectionReddit(BaseConnection):
SERVICE_NAME = "reddit"
AUTHORIZE_URL = "https://www.reddit.com/api/v1/authorize"
TOKEN_URL = "https://www.reddit.com/api/v1/access_token"
USER_URL = "https://oauth.reddit.com/api/v1/me"
SCOPE: list[str] = ["identity"]

@classmethod
async def authorize_url(cls, user: User) -> str:
return f"{await super(cls, ConnectionReddit).authorize_url(user)}&response_type=code"

@classmethod
def exchange_code_req(cls, code: str, settings: dict[str, str]) -> tuple[str, dict]:
callback_url = quote(f"https://{Config.PUBLIC_HOST}/connections/reddit/callback", safe="")
kwargs = {
"auth": (settings["client_id"], settings["client_secret"]),
"headers": {"Accept": "application/json", "Content-Type": "application/x-www-form-urlencoded"},
"content": f"grant_type=authorization_code&code={code}&redirect_uri={callback_url}",
}

return cls.TOKEN_URL, kwargs


class ConnectionTwitch(BaseConnection):
SERVICE_NAME = "twitch"
AUTHORIZE_URL = "https://id.twitch.tv/oauth2/authorize"
TOKEN_URL = "https://id.twitch.tv/oauth2/token"
USER_URL = "https://api.twitch.tv/helix/users"
SCOPE: list[str] = ["channel_subscriptions", "channel_check_subscription", "channel:read:subscriptions"]

@classmethod
async def authorize_url(cls, user: User) -> str:
return f"{await super(cls, ConnectionTwitch).authorize_url(user)}&response_type=code"

@classmethod
def exchange_code_req(cls, code: str, settings: dict[str, str]) -> tuple[str, dict]:
callback_url = quote(f"https://{Config.PUBLIC_HOST}/connections/twitch/callback", safe="")
kwargs = {
"headers": {"Accept": "application/json", "Content-Type": "application/x-www-form-urlencoded"},
"content": f"grant_type=authorization_code&code={code}&redirect_uri={callback_url}"
f"&client_id={settings['client_id']}&client_secret={settings['client_secret']}",
}

return cls.TOKEN_URL, kwargs

@classmethod
def user_info_req(cls, access_token: str) -> tuple[str, dict]:
client_id = get_service_settings(cls.SERVICE_NAME, "client_id")["client_id"]
return cls.USER_URL, {"headers": {"Authorization": f"Bearer {access_token}", "Client-Id": client_id}}

@classmethod
async def get_user_info(cls, access_token: str) -> dict:
return (await super(cls, ConnectionTwitch).get_user_info(access_token))["data"][0]
1 change: 1 addition & 0 deletions yepcord/yepcord/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ class ConfigConnectionBase(BaseModel):
class ConfigConnections(BaseModel):
github: ConfigConnectionBase = Field(default_factory=ConfigConnectionBase)
reddit: ConfigConnectionBase = Field(default_factory=ConfigConnectionBase)
twitch: ConfigConnectionBase = Field(default_factory=ConfigConnectionBase)


class ConfigModel(BaseModel):
Expand Down

0 comments on commit d09bb09

Please sign in to comment.