diff --git a/tests/test_client.py b/tests/test_client.py index 62a4d01b..4523a5aa 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -43,7 +43,7 @@ def test_initialize_sso_missing_client_id(self, set_api_key): message = str(ex) - assert "project_id" in message + assert "client_id" in message assert "api_key" not in message def test_initialize_sso_missing_api_key_and_project_id(self): @@ -52,7 +52,7 @@ def test_initialize_sso_missing_api_key_and_project_id(self): message = str(ex) - assert all(setting in message for setting in ("api_key", "project_id",)) + assert all(setting in message for setting in ("api_key", "client_id",)) def test_initialize_audit_trail_missing_api_key(self): with pytest.raises(ConfigurationException) as ex: diff --git a/tests/test_sso.py b/tests/test_sso.py index 40e3f0e5..6e02f2d5 100644 --- a/tests/test_sso.py +++ b/tests/test_sso.py @@ -11,8 +11,17 @@ class TestSSO(object): - @pytest.fixture(autouse=True) - def setup(self, set_api_key_and_project_id): + @pytest.fixture + def setup_with_client_id(self, set_api_key_and_client_id): + self.provider = ConnectionType.GoogleOAuth + self.customer_domain = "workos.com" + self.redirect_uri = "https://localhost/auth/callback" + self.state = json.dumps({"things": "with_stuff",}) + + self.sso = SSO() + + @pytest.fixture + def setup_with_project_id(self, set_api_key_and_project_id): self.provider = ConnectionType.GoogleOAuth self.customer_domain = "workos.com" self.redirect_uri = "https://localhost/auth/callback" @@ -91,14 +100,16 @@ def mock_connections(self): } def test_authorization_url_throws_value_error_with_missing_domain_and_provider( - self, + self, setup_with_client_id ): with pytest.raises(ValueError, match=r"Incomplete arguments.*"): self.sso.get_authorization_url( redirect_uri=self.redirect_uri, state=self.state ) - def test_authorization_url_throws_value_error_with_incorrect_provider_type(self): + def test_authorization_url_throws_value_error_with_incorrect_provider_type( + self, setup_with_client_id + ): with pytest.raises( ValueError, match="'provider' must be of type ConnectionType" ): @@ -106,7 +117,9 @@ def test_authorization_url_throws_value_error_with_incorrect_provider_type(self) provider="foo", redirect_uri=self.redirect_uri, state=self.state ) - def test_authorization_url_has_expected_query_params_with_provider(self): + def test_authorization_url_has_expected_query_params_with_provider( + self, setup_with_client_id + ): authorization_url = self.sso.get_authorization_url( provider=self.provider, redirect_uri=self.redirect_uri, state=self.state ) @@ -115,13 +128,15 @@ def test_authorization_url_has_expected_query_params_with_provider(self): assert dict(parse_qsl(parsed_url.query)) == { "provider": str(self.provider.value), - "client_id": workos.project_id, + "client_id": workos.client_id, "redirect_uri": self.redirect_uri, "response_type": RESPONSE_TYPE_CODE, "state": self.state, } - def test_authorization_url_has_expected_query_params_with_domain(self): + def test_authorization_url_has_expected_query_params_with_domain( + self, setup_with_client_id + ): authorization_url = self.sso.get_authorization_url( domain=self.customer_domain, redirect_uri=self.redirect_uri, @@ -132,13 +147,15 @@ def test_authorization_url_has_expected_query_params_with_domain(self): assert dict(parse_qsl(parsed_url.query)) == { "domain": self.customer_domain, - "client_id": workos.project_id, + "client_id": workos.client_id, "redirect_uri": self.redirect_uri, "response_type": RESPONSE_TYPE_CODE, "state": self.state, } - def test_authorization_url_has_expected_query_params_with_domain_and_provider(self): + def test_authorization_url_has_expected_query_params_with_domain_and_provider( + self, setup_with_client_id + ): authorization_url = self.sso.get_authorization_url( domain=self.customer_domain, provider=self.provider, @@ -151,14 +168,36 @@ def test_authorization_url_has_expected_query_params_with_domain_and_provider(se assert dict(parse_qsl(parsed_url.query)) == { "domain": self.customer_domain, "provider": str(self.provider.value), - "client_id": workos.project_id, + "client_id": workos.client_id, "redirect_uri": self.redirect_uri, "response_type": RESPONSE_TYPE_CODE, "state": self.state, } + def test_authorization_url_supports_project_id_with_deprecation_warning( + self, setup_with_project_id + ): + with pytest.deprecated_call(): + authorization_url = self.sso.get_authorization_url( + domain=self.customer_domain, + provider=self.provider, + redirect_uri=self.redirect_uri, + state=self.state, + ) + + parsed_url = urlparse(authorization_url) + + assert dict(parse_qsl(parsed_url.query)) == { + "domain": self.customer_domain, + "provider": str(self.provider.value), + "client_id": workos.project_id, + "redirect_uri": self.redirect_uri, + "response_type": RESPONSE_TYPE_CODE, + "state": self.state, + } + def test_get_profile_returns_expected_workosprofile_object( - self, mock_profile, mock_request_method + self, setup_with_client_id, mock_profile, mock_request_method ): response_dict = { "profile": { @@ -185,7 +224,9 @@ def test_get_profile_returns_expected_workosprofile_object( assert profile.to_dict() == mock_profile - def test_create_connection(self, mock_request_method, mock_connection): + def test_create_connection( + self, setup_with_client_id, mock_request_method, mock_connection + ): response_dict = { "object": "connection", "id": mock_connection["id"], @@ -208,7 +249,9 @@ def test_create_connection(self, mock_request_method, mock_connection): connection = self.sso.create_connection("draft_conn_id") assert connection == response_dict - def test_get_connection(self, mock_connection, mock_request_method): + def test_get_connection( + self, setup_with_client_id, mock_connection, mock_request_method + ): mock_response = Response() mock_response.status_code = 200 mock_response.response_dict = mock_connection @@ -217,7 +260,9 @@ def test_get_connection(self, mock_connection, mock_request_method): assert response.status_code == 200 assert response.response_dict == mock_connection - def test_list_connections(self, mock_connections, mock_request_method): + def test_list_connections( + self, setup_with_client_id, mock_connections, mock_request_method + ): mock_response = Response() mock_response.status_code = 200 mock_response.response_dict = mock_connections @@ -226,7 +271,7 @@ def test_list_connections(self, mock_connections, mock_request_method): assert response.status_code == 200 assert response.response_dict == mock_connections - def test_delete_connection(self, mock_request_method): + def test_delete_connection(self, setup_with_client_id, mock_request_method): mock_response = Response() mock_response.status_code = 200 mock_request_method("delete", mock_response, 200) diff --git a/workos/__about__.py b/workos/__about__.py index f53028ed..2781df44 100644 --- a/workos/__about__.py +++ b/workos/__about__.py @@ -12,7 +12,7 @@ __package_url__ = "https://github.com/workos-inc/workos-python" -__version__ = "0.8.3" +__version__ = "0.8.4" __author__ = "WorkOS" diff --git a/workos/utils/validation.py b/workos/utils/validation.py index 7f4699e5..139be8f4 100644 --- a/workos/utils/validation.py +++ b/workos/utils/validation.py @@ -14,7 +14,7 @@ DIRECTORY_SYNC_MODULE: ["api_key",], PASSWORDLESS_MODULE: ["api_key",], PORTAL_MODULE: ["api_key",], - SSO_MODULE: ["api_key", "project_id",], + SSO_MODULE: ["api_key", "client_id",], } @@ -23,9 +23,20 @@ def decorator(fn): @wraps(fn) def wrapper(*args, **kwargs): missing_settings = [] - for setting in REQUIRED_SETTINGS_FOR_MODULE[module_name]: - if not getattr(workos, setting, None): - missing_settings.append(setting) + + # Adding this to accept both client_id and project_id + # can remove once project_id is deprecated + if module_name == SSO_MODULE: + if not getattr(workos, "api_key", None): + missing_settings.append("api_key") + if not getattr(workos, "client_id", None) and not getattr( + workos, "project_id", None + ): + missing_settings.append("client_id") + else: + for setting in REQUIRED_SETTINGS_FOR_MODULE[module_name]: + if not getattr(workos, setting, None): + missing_settings.append(setting) if missing_settings: raise ConfigurationException(