Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_initialize_sso_missing_client_id(self, set_api_key):

message = str(ex)

assert "project_id" in message
assert "client_id" in message
assert "api_key" not in message

def test_initialize_sso_missing_api_key_and_project_id(self):
Expand All @@ -52,7 +52,7 @@ def test_initialize_sso_missing_api_key_and_project_id(self):

message = str(ex)

assert all(setting in message for setting in ("api_key", "project_id",))
assert all(setting in message for setting in ("api_key", "client_id",))

def test_initialize_audit_trail_missing_api_key(self):
with pytest.raises(ConfigurationException) as ex:
Expand Down
75 changes: 60 additions & 15 deletions tests/test_sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,17 @@


class TestSSO(object):
@pytest.fixture(autouse=True)
def setup(self, set_api_key_and_project_id):
@pytest.fixture
def setup_with_client_id(self, set_api_key_and_client_id):
self.provider = ConnectionType.GoogleOAuth
self.customer_domain = "workos.com"
self.redirect_uri = "https://localhost/auth/callback"
self.state = json.dumps({"things": "with_stuff",})

self.sso = SSO()

@pytest.fixture
def setup_with_project_id(self, set_api_key_and_project_id):
self.provider = ConnectionType.GoogleOAuth
self.customer_domain = "workos.com"
self.redirect_uri = "https://localhost/auth/callback"
Expand Down Expand Up @@ -91,22 +100,26 @@ def mock_connections(self):
}

def test_authorization_url_throws_value_error_with_missing_domain_and_provider(
self,
self, setup_with_client_id
):
with pytest.raises(ValueError, match=r"Incomplete arguments.*"):
self.sso.get_authorization_url(
redirect_uri=self.redirect_uri, state=self.state
)

def test_authorization_url_throws_value_error_with_incorrect_provider_type(self):
def test_authorization_url_throws_value_error_with_incorrect_provider_type(
self, setup_with_client_id
):
with pytest.raises(
ValueError, match="'provider' must be of type ConnectionType"
):
self.sso.get_authorization_url(
provider="foo", redirect_uri=self.redirect_uri, state=self.state
)

def test_authorization_url_has_expected_query_params_with_provider(self):
def test_authorization_url_has_expected_query_params_with_provider(
self, setup_with_client_id
):
authorization_url = self.sso.get_authorization_url(
provider=self.provider, redirect_uri=self.redirect_uri, state=self.state
)
Expand All @@ -115,13 +128,15 @@ def test_authorization_url_has_expected_query_params_with_provider(self):

assert dict(parse_qsl(parsed_url.query)) == {
"provider": str(self.provider.value),
"client_id": workos.project_id,
"client_id": workos.client_id,
"redirect_uri": self.redirect_uri,
"response_type": RESPONSE_TYPE_CODE,
"state": self.state,
}

def test_authorization_url_has_expected_query_params_with_domain(self):
def test_authorization_url_has_expected_query_params_with_domain(
self, setup_with_client_id
):
authorization_url = self.sso.get_authorization_url(
domain=self.customer_domain,
redirect_uri=self.redirect_uri,
Expand All @@ -132,13 +147,15 @@ def test_authorization_url_has_expected_query_params_with_domain(self):

assert dict(parse_qsl(parsed_url.query)) == {
"domain": self.customer_domain,
"client_id": workos.project_id,
"client_id": workos.client_id,
"redirect_uri": self.redirect_uri,
"response_type": RESPONSE_TYPE_CODE,
"state": self.state,
}

def test_authorization_url_has_expected_query_params_with_domain_and_provider(self):
def test_authorization_url_has_expected_query_params_with_domain_and_provider(
self, setup_with_client_id
):
authorization_url = self.sso.get_authorization_url(
domain=self.customer_domain,
provider=self.provider,
Expand All @@ -151,14 +168,36 @@ def test_authorization_url_has_expected_query_params_with_domain_and_provider(se
assert dict(parse_qsl(parsed_url.query)) == {
"domain": self.customer_domain,
"provider": str(self.provider.value),
"client_id": workos.project_id,
"client_id": workos.client_id,
"redirect_uri": self.redirect_uri,
"response_type": RESPONSE_TYPE_CODE,
"state": self.state,
}

def test_authorization_url_supports_project_id_with_deprecation_warning(
self, setup_with_project_id
):
with pytest.deprecated_call():
authorization_url = self.sso.get_authorization_url(
domain=self.customer_domain,
provider=self.provider,
redirect_uri=self.redirect_uri,
state=self.state,
)

parsed_url = urlparse(authorization_url)

assert dict(parse_qsl(parsed_url.query)) == {
"domain": self.customer_domain,
"provider": str(self.provider.value),
"client_id": workos.project_id,
"redirect_uri": self.redirect_uri,
"response_type": RESPONSE_TYPE_CODE,
"state": self.state,
}

def test_get_profile_returns_expected_workosprofile_object(
self, mock_profile, mock_request_method
self, setup_with_client_id, mock_profile, mock_request_method
):
response_dict = {
"profile": {
Expand All @@ -185,7 +224,9 @@ def test_get_profile_returns_expected_workosprofile_object(

assert profile.to_dict() == mock_profile

def test_create_connection(self, mock_request_method, mock_connection):
def test_create_connection(
self, setup_with_client_id, mock_request_method, mock_connection
):
response_dict = {
"object": "connection",
"id": mock_connection["id"],
Expand All @@ -208,7 +249,9 @@ def test_create_connection(self, mock_request_method, mock_connection):
connection = self.sso.create_connection("draft_conn_id")
assert connection == response_dict

def test_get_connection(self, mock_connection, mock_request_method):
def test_get_connection(
self, setup_with_client_id, mock_connection, mock_request_method
):
mock_response = Response()
mock_response.status_code = 200
mock_response.response_dict = mock_connection
Expand All @@ -217,7 +260,9 @@ def test_get_connection(self, mock_connection, mock_request_method):
assert response.status_code == 200
assert response.response_dict == mock_connection

def test_list_connections(self, mock_connections, mock_request_method):
def test_list_connections(
self, setup_with_client_id, mock_connections, mock_request_method
):
mock_response = Response()
mock_response.status_code = 200
mock_response.response_dict = mock_connections
Expand All @@ -226,7 +271,7 @@ def test_list_connections(self, mock_connections, mock_request_method):
assert response.status_code == 200
assert response.response_dict == mock_connections

def test_delete_connection(self, mock_request_method):
def test_delete_connection(self, setup_with_client_id, mock_request_method):
mock_response = Response()
mock_response.status_code = 200
mock_request_method("delete", mock_response, 200)
Expand Down
2 changes: 1 addition & 1 deletion workos/__about__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

__package_url__ = "https://github.com/workos-inc/workos-python"

__version__ = "0.8.3"
__version__ = "0.8.4"

__author__ = "WorkOS"

Expand Down
19 changes: 15 additions & 4 deletions workos/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
DIRECTORY_SYNC_MODULE: ["api_key",],
PASSWORDLESS_MODULE: ["api_key",],
PORTAL_MODULE: ["api_key",],
SSO_MODULE: ["api_key", "project_id",],
SSO_MODULE: ["api_key", "client_id",],
}


Expand All @@ -23,9 +23,20 @@ def decorator(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
missing_settings = []
for setting in REQUIRED_SETTINGS_FOR_MODULE[module_name]:
if not getattr(workos, setting, None):
missing_settings.append(setting)

# Adding this to accept both client_id and project_id
# can remove once project_id is deprecated
if module_name == SSO_MODULE:
if not getattr(workos, "api_key", None):
missing_settings.append("api_key")
if not getattr(workos, "client_id", None) and not getattr(
workos, "project_id", None
):
missing_settings.append("client_id")
else:
for setting in REQUIRED_SETTINGS_FOR_MODULE[module_name]:
if not getattr(workos, setting, None):
missing_settings.append(setting)

if missing_settings:
raise ConfigurationException(
Expand Down