-
-
Notifications
You must be signed in to change notification settings - Fork 9.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
🔒 Generic Oauth installer #1150
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -99,6 +99,39 @@ model VerificationToken { | |
@@unique([identifier, token]) | ||
} | ||
|
||
model OAuthCredentials { | ||
id String @id @default(cuid()) | ||
installation_id String | ||
provider String | ||
token_type String | ||
access_token String | ||
scope String? | ||
data Json | ||
|
||
create_date DateTime @default(now()) | ||
update_date DateTime? @updatedAt | ||
delete_date DateTime? | ||
|
||
@@unique([installation_id]) | ||
@@map("oauth_credentials") | ||
} | ||
|
||
model OAuthInstallation { | ||
id String @id @default(cuid()) | ||
user_id String | ||
Comment on lines
+120
to
+121
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we need to be more specific with id? is it an account identifier? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nope, it maps directly to the user table :) |
||
organization_id String? | ||
provider String | ||
state String | ||
|
||
create_date DateTime @default(now()) | ||
update_date DateTime? @updatedAt | ||
delete_date DateTime? | ||
|
||
@@unique([user_id, organization_id, provider]) | ||
@@index([state]) | ||
@@map("oauth_installation") | ||
} | ||
|
||
model Agent { | ||
id String @id @default(cuid()) | ||
userId String | ||
|
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,9 @@ | ||
from typing import TypeVar | ||
|
||
from sqlalchemy.ext.asyncio import AsyncSession | ||
|
||
T = TypeVar("T", bound="BaseCrud") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What does this do? |
||
|
||
|
||
class BaseCrud: | ||
def __init__(self, session: AsyncSession): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import secrets | ||
from typing import Optional | ||
|
||
from fastapi import Depends | ||
from sqlalchemy import select | ||
from sqlalchemy.ext.asyncio import AsyncSession | ||
|
||
from reworkd_platform.db.crud.base import BaseCrud | ||
from reworkd_platform.db.dependencies import get_db_session | ||
from reworkd_platform.db.models.auth import OauthInstallation | ||
from reworkd_platform.schemas import UserBase | ||
|
||
|
||
class OAuthCrud(BaseCrud): | ||
@classmethod | ||
async def inject( | ||
cls, | ||
session: AsyncSession = Depends(get_db_session), | ||
) -> "OAuthCrud": | ||
return cls(session) | ||
|
||
async def create_installation( | ||
self, user: UserBase, provider: str | ||
) -> OauthInstallation: | ||
return await OauthInstallation( | ||
user_id=user.id, | ||
organization_id=user.organization_id, | ||
provider=provider, | ||
state=secrets.token_hex(16), | ||
).save(self.session) | ||
|
||
async def get_installation_by_state( | ||
self, state: str | ||
) -> Optional[OauthInstallation]: | ||
query = select(OauthInstallation).filter(OauthInstallation.state == state) | ||
|
||
return (await self.session.execute(query)).scalar_one_or_none() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# from slack.web import WebClient | ||
from abc import ABC, abstractmethod | ||
from typing import TypeVar | ||
|
||
from fastapi import Depends, Path | ||
from slack_sdk import WebClient | ||
from slack_sdk.oauth import AuthorizeUrlGenerator | ||
|
||
from reworkd_platform.db.crud.oauth import OAuthCrud | ||
from reworkd_platform.db.models.auth import OauthCredentials | ||
from reworkd_platform.schemas import UserBase | ||
from reworkd_platform.settings import Settings, settings as platform_settings | ||
from reworkd_platform.web.api.http_responses import forbidden | ||
|
||
T = TypeVar("T", bound="OAuthInstaller") | ||
|
||
|
||
class OAuthInstaller(ABC): | ||
def __init__(self, crud: OAuthCrud, settings: Settings): | ||
self.crud = crud | ||
self.settings = settings | ||
|
||
@abstractmethod | ||
async def install(self, user: UserBase) -> str: | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
async def install_callback(self, code: str, state: str) -> None: | ||
raise NotImplementedError() | ||
|
||
|
||
class SlackInstaller(OAuthInstaller): | ||
PROVIDER = "slack" | ||
|
||
async def install(self, user: UserBase) -> str: | ||
installation = await self.crud.create_installation(user, self.PROVIDER) | ||
|
||
return AuthorizeUrlGenerator( | ||
client_id=self.settings.slack_client_id, | ||
redirect_uri=self.settings.slack_redirect_uri, | ||
scopes=["chat:write"], | ||
).generate( | ||
state=installation.state, | ||
) | ||
|
||
async def install_callback(self, code: str, state: str) -> None: | ||
installation = await self.crud.get_installation_by_state(state) | ||
if not installation: | ||
raise forbidden() | ||
|
||
oauth_response = WebClient().oauth_v2_access( | ||
client_id=self.settings.slack_client_id, | ||
client_secret=self.settings.slack_client_secret, | ||
code=code, | ||
state=state, | ||
) | ||
|
||
# We should handle token rotation / refresh tokens eventually | ||
# TODO: encode token | ||
await OauthCredentials( | ||
installation_id=installation.id, | ||
provider="slack", | ||
token_type=oauth_response["token_type"], | ||
access_token=oauth_response["access_token"], | ||
scope=oauth_response["scope"], | ||
data=oauth_response.data, | ||
).save(self.crud.session) | ||
|
||
|
||
integrations = { | ||
SlackInstaller.PROVIDER: SlackInstaller, | ||
} | ||
|
||
|
||
def installer_factory( | ||
provider: str = Path(description="OAuth Provider"), | ||
crud: OAuthCrud = Depends(OAuthCrud.inject), | ||
) -> OAuthInstaller: | ||
if provider in integrations: | ||
return integrations[provider](crud, platform_settings) | ||
raise NotImplementedError() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -102,6 +102,11 @@ class Settings(BaseSettings): | |
ff_mock_mode_enabled: bool = False # Controls whether calls are mocked | ||
max_loops: int = 25 # Maximum number of loops to run | ||
|
||
# Settings for slack | ||
slack_client_id: str = "" | ||
slack_client_secret: str = "" | ||
slack_redirect_uri: str = "" | ||
|
||
Comment on lines
+105
to
+109
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't this be in the DB now? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No this is for our app, we are installing our app into other peoples slack |
||
@property | ||
def kafka_consumer_group(self) -> str: | ||
""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import pytest | ||
|
||
from reworkd_platform.services.oauth_installers import installer_factory | ||
|
||
|
||
def test_installer_factory(mocker): | ||
crud = mocker.Mock() | ||
installer_factory("slack", crud) | ||
|
||
|
||
def test_integration_dne(mocker): | ||
crud = mocker.Mock() | ||
|
||
with pytest.raises(NotImplementedError): | ||
installer_factory("asim", crud) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🤨 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,10 @@ | |
|
||
from reworkd_platform.db.crud.organization import OrganizationCrud, OrganizationUsers | ||
from reworkd_platform.schemas import UserBase | ||
from reworkd_platform.services.oauth_installers import ( | ||
installer_factory, | ||
OAuthInstaller, | ||
) | ||
from reworkd_platform.services.sockets import websockets | ||
from reworkd_platform.web.api.dependencies import get_current_user | ||
|
||
|
@@ -44,3 +48,24 @@ async def pusher_authentication( | |
user: UserBase = Depends(get_current_user), | ||
) -> Dict[str, str]: | ||
return websockets.authenticate(user, channel_name, socket_id) | ||
|
||
|
||
@router.get("/{provider}") | ||
async def oauth_install( | ||
user: UserBase = Depends(get_current_user), | ||
installer: OAuthInstaller = Depends(installer_factory), | ||
) -> str: | ||
"""Install an OAuth App""" | ||
url = await installer.install(user) | ||
print(url) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: print |
||
return url | ||
|
||
|
||
@router.get("/{provider}/callback") | ||
async def oauth_callback( | ||
code: str, | ||
state: str, | ||
installer: OAuthInstaller = Depends(installer_factory), | ||
) -> None: | ||
"""Callback for OAuth App""" | ||
return await installer.install_callback(code, state) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
cd "$(dirname "$0")" || exit 1 | ||
git reset --hard | ||
|
||
git fetch origin | ||
|
||
git checkout main | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How does this typically work. As in, do all oauth systems typically use a string to define scope? How will we deal with scopes that require a list of strings