Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔒 Generic Oauth installer #1150

Merged
merged 5 commits into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions next/prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,39 @@ model VerificationToken {
@@unique([identifier, token])
}

model OAuthCredentials {
id String @id @default(cuid())
installation_id String
provider String
token_type String
access_token String
scope String?
Copy link
Contributor

@asim-shrestha asim-shrestha Jul 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does this typically work. As in, do all oauth systems typically use a string to define scope? How will we deal with scopes that require a list of strings

data Json

create_date DateTime @default(now())
update_date DateTime? @updatedAt
delete_date DateTime?

@@unique([installation_id])
@@map("oauth_credentials")
}

model OAuthInstallation {
id String @id @default(cuid())
user_id String
Comment on lines +120 to +121
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to be more specific with id? is it an account identifier?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nope, it maps directly to the user table :)

organization_id String?
provider String
state String

create_date DateTime @default(now())
update_date DateTime? @updatedAt
delete_date DateTime?

@@unique([user_id, organization_id, provider])
@@index([state])
@@map("oauth_installation")
}

model Agent {
id String @id @default(cuid())
userId String
Expand Down
117 changes: 101 additions & 16 deletions platform/poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion platform/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ httptools = "^0.5.0"
sentry-sdk = "^1.28.1"
loguru = "^0.7.0"
aiokafka = "^0.8.1"
requests = "2.28.0"
requests = "^2.31.0"
langchain = "0.0.218"
openai = "^0.27.8"
wikipedia = "^1.4.0"
Expand All @@ -47,6 +47,7 @@ aws-secretsmanager-caching = "^1.1.1.5"
botocore = "^1.29.153"
stripe = "^5.4.0"
tabula-py = "^2.7.0"
slack-sdk = "^3.21.3"

[tool.poetry.dev-dependencies]
autopep8 = "^2.0.2"
Expand Down
4 changes: 4 additions & 0 deletions platform/reworkd_platform/db/crud/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from typing import TypeVar

from sqlalchemy.ext.asyncio import AsyncSession

T = TypeVar("T", bound="BaseCrud")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this do?



class BaseCrud:
def __init__(self, session: AsyncSession):
Expand Down
37 changes: 37 additions & 0 deletions platform/reworkd_platform/db/crud/oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import secrets
from typing import Optional

from fastapi import Depends
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from reworkd_platform.db.crud.base import BaseCrud
from reworkd_platform.db.dependencies import get_db_session
from reworkd_platform.db.models.auth import OauthInstallation
from reworkd_platform.schemas import UserBase


class OAuthCrud(BaseCrud):
@classmethod
async def inject(
cls,
session: AsyncSession = Depends(get_db_session),
) -> "OAuthCrud":
return cls(session)

async def create_installation(
self, user: UserBase, provider: str
) -> OauthInstallation:
return await OauthInstallation(
user_id=user.id,
organization_id=user.organization_id,
provider=provider,
state=secrets.token_hex(16),
).save(self.session)

async def get_installation_by_state(
self, state: str
) -> Optional[OauthInstallation]:
query = select(OauthInstallation).filter(OauthInstallation.state == state)

return (await self.session.execute(query)).scalar_one_or_none()
22 changes: 21 additions & 1 deletion platform/reworkd_platform/db/models/auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sqlalchemy import String
from sqlalchemy import String, JSON
from sqlalchemy.orm import mapped_column

from reworkd_platform.db.base import TrackedModel
Expand All @@ -17,3 +17,23 @@ class OrganizationUser(TrackedModel):
user_id = mapped_column(String, nullable=False)
organization_id = mapped_column(String, nullable=False)
role = mapped_column(String, nullable=False, default="member")


class OauthCredentials(TrackedModel):
__tablename__ = "oauth_credentials"

installation_id = mapped_column(String, nullable=False)
provider = mapped_column(String, nullable=False)
token_type = mapped_column(String, nullable=False)
access_token = mapped_column(String, nullable=False)
scope = mapped_column(String, nullable=True)
data = mapped_column(JSON, nullable=False)


class OauthInstallation(TrackedModel):
__tablename__ = "oauth_installation"

user_id = mapped_column(String, nullable=False)
organization_id = mapped_column(String, nullable=True)
provider = mapped_column(String, nullable=False)
state = mapped_column(String, nullable=False)
81 changes: 81 additions & 0 deletions platform/reworkd_platform/services/oauth_installers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# from slack.web import WebClient
from abc import ABC, abstractmethod
from typing import TypeVar

from fastapi import Depends, Path
from slack_sdk import WebClient
from slack_sdk.oauth import AuthorizeUrlGenerator

from reworkd_platform.db.crud.oauth import OAuthCrud
from reworkd_platform.db.models.auth import OauthCredentials
from reworkd_platform.schemas import UserBase
from reworkd_platform.settings import Settings, settings as platform_settings
from reworkd_platform.web.api.http_responses import forbidden

T = TypeVar("T", bound="OAuthInstaller")


class OAuthInstaller(ABC):
def __init__(self, crud: OAuthCrud, settings: Settings):
self.crud = crud
self.settings = settings

@abstractmethod
async def install(self, user: UserBase) -> str:
raise NotImplementedError()

@abstractmethod
async def install_callback(self, code: str, state: str) -> None:
raise NotImplementedError()


class SlackInstaller(OAuthInstaller):
PROVIDER = "slack"

async def install(self, user: UserBase) -> str:
installation = await self.crud.create_installation(user, self.PROVIDER)

return AuthorizeUrlGenerator(
client_id=self.settings.slack_client_id,
redirect_uri=self.settings.slack_redirect_uri,
scopes=["chat:write"],
).generate(
state=installation.state,
)

async def install_callback(self, code: str, state: str) -> None:
installation = await self.crud.get_installation_by_state(state)
if not installation:
raise forbidden()

oauth_response = WebClient().oauth_v2_access(
client_id=self.settings.slack_client_id,
client_secret=self.settings.slack_client_secret,
code=code,
state=state,
)

# We should handle token rotation / refresh tokens eventually
# TODO: encode token
await OauthCredentials(
installation_id=installation.id,
provider="slack",
token_type=oauth_response["token_type"],
access_token=oauth_response["access_token"],
scope=oauth_response["scope"],
data=oauth_response.data,
).save(self.crud.session)


integrations = {
SlackInstaller.PROVIDER: SlackInstaller,
}


def installer_factory(
provider: str = Path(description="OAuth Provider"),
crud: OAuthCrud = Depends(OAuthCrud.inject),
) -> OAuthInstaller:
if provider in integrations:
return integrations[provider](crud, platform_settings)
raise NotImplementedError()
5 changes: 5 additions & 0 deletions platform/reworkd_platform/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ class Settings(BaseSettings):
ff_mock_mode_enabled: bool = False # Controls whether calls are mocked
max_loops: int = 25 # Maximum number of loops to run

# Settings for slack
slack_client_id: str = ""
slack_client_secret: str = ""
slack_redirect_uri: str = ""

Comment on lines +105 to +109
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be in the DB now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No this is for our app, we are installing our app into other peoples slack

@property
def kafka_consumer_group(self) -> str:
"""
Expand Down
15 changes: 15 additions & 0 deletions platform/reworkd_platform/tests/test_oauth_installers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pytest

from reworkd_platform.services.oauth_installers import installer_factory


def test_installer_factory(mocker):
crud = mocker.Mock()
installer_factory("slack", crud)


def test_integration_dne(mocker):
crud = mocker.Mock()

with pytest.raises(NotImplementedError):
installer_factory("asim", crud)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤨

6 changes: 4 additions & 2 deletions platform/reworkd_platform/tests/workflow/test_if_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@
],
)
async def test_if_condition_success(value_one, operator, value_two, expected_result):
workflow_id = "123"
block = IfCondition(
input=IfInput(value_one=value_one, operator=operator, value_two=value_two)
)
result = await block.run(curr.workflow_id)
result = await block.run(workflow_id)
assert result == IfOutput(result=expected_result)


Expand All @@ -59,8 +60,9 @@ async def test_if_condition_success(value_one, operator, value_two, expected_res
],
)
async def test_if_condition_errors(value_one, operator, value_two):
workflow_id = "123"
block = IfCondition(
input=IfInput(value_one=value_one, operator=operator, value_two=value_two)
)
with pytest.raises(ValueError):
await block.run(curr.workflow_id)
await block.run(workflow_id)
25 changes: 25 additions & 0 deletions platform/reworkd_platform/web/api/auth/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

from reworkd_platform.db.crud.organization import OrganizationCrud, OrganizationUsers
from reworkd_platform.schemas import UserBase
from reworkd_platform.services.oauth_installers import (
installer_factory,
OAuthInstaller,
)
from reworkd_platform.services.sockets import websockets
from reworkd_platform.web.api.dependencies import get_current_user

Expand Down Expand Up @@ -44,3 +48,24 @@ async def pusher_authentication(
user: UserBase = Depends(get_current_user),
) -> Dict[str, str]:
return websockets.authenticate(user, channel_name, socket_id)


@router.get("/{provider}")
async def oauth_install(
user: UserBase = Depends(get_current_user),
installer: OAuthInstaller = Depends(installer_factory),
) -> str:
"""Install an OAuth App"""
url = await installer.install(user)
print(url)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: print

return url


@router.get("/{provider}/callback")
async def oauth_callback(
code: str,
state: str,
installer: OAuthInstaller = Depends(installer_factory),
) -> None:
"""Callback for OAuth App"""
return await installer.install_callback(code, state)
2 changes: 2 additions & 0 deletions scripts/prepare-sync.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
cd "$(dirname "$0")" || exit 1
git reset --hard

git fetch origin

git checkout main
Expand Down
Loading