diff --git a/requirements.txt b/requirements.txt index 3b962207..8d364e41 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ grpcio>=1.42.0 packaging protobuf>=3.13.0,<5.0.0 aiohttp<4 +pyjwt==2.8.0 diff --git a/test-requirements.txt b/test-requirements.txt index 43c61a10..21da70d3 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -46,4 +46,5 @@ pylint-protobuf cython freezegun==1.2.2 pytest-cov +yandexcloud -e . diff --git a/tests/aio/test_credentials.py b/tests/aio/test_credentials.py new file mode 100644 index 00000000..10878a40 --- /dev/null +++ b/tests/aio/test_credentials.py @@ -0,0 +1,58 @@ +import pytest +import time +import grpc +import threading + +import tests.auth.test_credentials +import ydb.aio.iam + + +class TestServiceAccountCredentials(ydb.aio.iam.ServiceAccountCredentials): + def _channel_factory(self): + return grpc.aio.insecure_channel(self._iam_endpoint) + + def get_expire_time(self): + return self._expires_in - time.time() + + +class TestNebiusServiceAccountCredentials(ydb.aio.iam.NebiusServiceAccountCredentials): + def get_expire_time(self): + return self._expires_in - time.time() + + +@pytest.mark.asyncio +async def test_yandex_service_account_credentials(): + server = tests.auth.test_credentials.IamTokenServiceTestServer() + credentials = TestServiceAccountCredentials( + tests.auth.test_credentials.SERVICE_ACCOUNT_ID, + tests.auth.test_credentials.ACCESS_KEY_ID, + tests.auth.test_credentials.PRIVATE_KEY, + server.get_endpoint(), + ) + t = (await credentials.auth_metadata())[0][1] + assert t == "test_token" + assert credentials.get_expire_time() <= 42 + server.stop() + + +@pytest.mark.asyncio +async def test_nebius_service_account_credentials(): + server = tests.auth.test_credentials.NebiusTokenServiceForTest() + + def serve(s): + s.handle_request() + + serve_thread = threading.Thread(target=serve, args=(server,)) + serve_thread.start() + + credentials = TestNebiusServiceAccountCredentials( + tests.auth.test_credentials.SERVICE_ACCOUNT_ID, + tests.auth.test_credentials.ACCESS_KEY_ID, + tests.auth.test_credentials.PRIVATE_KEY, + server.endpoint(), + ) + t = (await credentials.auth_metadata())[0][1] + assert t == "test_nebius_token" + assert credentials.get_expire_time() <= 42 + + serve_thread.join() diff --git a/tests/auth/__init__.py b/tests/auth/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/auth/test_credentials.py b/tests/auth/test_credentials.py new file mode 100644 index 00000000..bd4c9809 --- /dev/null +++ b/tests/auth/test_credentials.py @@ -0,0 +1,149 @@ +import jwt +import concurrent.futures +import grpc +import time +import http.server +import urllib +import threading +import json + +import ydb.iam + +from yandex.cloud.iam.v1 import iam_token_service_pb2_grpc +from yandex.cloud.iam.v1 import iam_token_service_pb2 + +SERVICE_ACCOUNT_ID = "sa_id" +ACCESS_KEY_ID = "key_id" +PRIVATE_KEY = "-----BEGIN PRIVATE KEY-----\nMIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC75/JS3rMcLJxv\nFgpOzF5+2gH+Yig3RE2MTl9uwC0BZKAv6foYr7xywQyWIK+W1cBhz8R4LfFmZo2j\nM0aCvdRmNBdW0EDSTnHLxCsFhoQWLVq+bI5f5jzkcoiioUtaEpADPqwgVULVtN/n\nnPJiZ6/dU30C3jmR6+LUgEntUtWt3eq3xQIn5lG3zC1klBY/HxtfH5Hu8xBvwRQT\nJnh3UpPLj8XwSmriDgdrhR7o6umWyVuGrMKlLHmeivlfzjYtfzO1MOIMG8t2/zxG\nR+xb4Vwks73sH1KruH/0/JMXU97npwpe+Um+uXhpldPygGErEia7abyZB2gMpXqr\nWYKMo02NAgMBAAECggEAO0BpC5OYw/4XN/optu4/r91bupTGHKNHlsIR2rDzoBhU\nYLd1evpTQJY6O07EP5pYZx9mUwUdtU4KRJeDGO/1/WJYp7HUdtxwirHpZP0lQn77\nuccuX/QQaHLrPekBgz4ONk+5ZBqukAfQgM7fKYOLk41jgpeDbM2Ggb6QUSsJISEp\nzrwpI/nNT/wn+Hvx4DxrzWU6wF+P8kl77UwPYlTA7GsT+T7eKGVH8xsxmK8pt6lg\nsvlBA5XosWBWUCGLgcBkAY5e4ZWbkdd183o+oMo78id6C+PQPE66PLDtHWfpRRmN\nm6XC03x6NVhnfvfozoWnmS4+e4qj4F/emCHvn0GMywKBgQDLXlj7YPFVXxZpUvg/\nrheVcCTGbNmQJ+4cZXx87huqwqKgkmtOyeWsRc7zYInYgraDrtCuDBCfP//ZzOh0\nLxepYLTPk5eNn/GT+VVrqsy35Ccr60g7Lp/bzb1WxyhcLbo0KX7/6jl0lP+VKtdv\nmto+4mbSBXSM1Y5BVVoVgJ3T/wKBgQDsiSvPRzVi5TTj13x67PFymTMx3HCe2WzH\nJUyepCmVhTm482zW95pv6raDr5CTO6OYpHtc5sTTRhVYEZoEYFTM9Vw8faBtluWG\nBjkRh4cIpoIARMn74YZKj0C/0vdX7SHdyBOU3bgRPHg08Hwu3xReqT1kEPSI/B2V\n4pe5fVrucwKBgQCNFgUxUA3dJjyMES18MDDYUZaRug4tfiYouRdmLGIxUxozv6CG\nZnbZzwxFt+GpvPUV4f+P33rgoCvFU+yoPctyjE6j+0aW0DFucPmb2kBwCu5J/856\nkFwCx3blbwFHAco+SdN7g2kcwgmV2MTg/lMOcU7XwUUcN0Obe7UlWbckzQKBgQDQ\nnXaXHL24GGFaZe4y2JFmujmNy1dEsoye44W9ERpf9h1fwsoGmmCKPp90az5+rIXw\nFXl8CUgk8lXW08db/r4r+ma8Lyx0GzcZyplAnaB5/6j+pazjSxfO4KOBy4Y89Tb+\nTP0AOcCi6ws13bgY+sUTa/5qKA4UVw+c5zlb7nRpgwKBgGXAXhenFw1666482iiN\ncHSgwc4ZHa1oL6aNJR1XWH+aboBSwR+feKHUPeT4jHgzRGo/aCNHD2FE5I8eBv33\nof1kWYjAO0YdzeKrW0rTwfvt9gGg+CS397aWu4cy+mTI+MNfBgeDAIVBeJOJXLlX\nhL8bFAuNNVrCOp79TNnNIsh7\n-----END PRIVATE KEY-----\n" # noqa: E501 +PUBLIC_KEY = "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAu+fyUt6zHCycbxYKTsxe\nftoB/mIoN0RNjE5fbsAtAWSgL+n6GK+8csEMliCvltXAYc/EeC3xZmaNozNGgr3U\nZjQXVtBA0k5xy8QrBYaEFi1avmyOX+Y85HKIoqFLWhKQAz6sIFVC1bTf55zyYmev\n3VN9At45kevi1IBJ7VLVrd3qt8UCJ+ZRt8wtZJQWPx8bXx+R7vMQb8EUEyZ4d1KT\ny4/F8Epq4g4Ha4Ue6OrplslbhqzCpSx5nor5X842LX8ztTDiDBvLdv88RkfsW+Fc\nJLO97B9Sq7h/9PyTF1Pe56cKXvlJvrl4aZXT8oBhKxImu2m8mQdoDKV6q1mCjKNN\njQIDAQAB\n-----END PUBLIC KEY-----\n" # noqa: E501 + + +def test_metadata_credentials(): + credentials = ydb.iam.MetadataUrlCredentials() + raised = False + try: + credentials.auth_metadata() + except Exception: + raised = True + + assert raised + + +class IamTokenServiceForTest(iam_token_service_pb2_grpc.IamTokenServiceServicer): + def Create(self, request, context): + print("IAM token service request: {}".format(request)) + # Validate jwt: + decoded = jwt.decode( + request.jwt, key=PUBLIC_KEY, algorithms=["PS256"], audience="https://iam.api.cloud.yandex.net/iam/v1/tokens" + ) + assert decoded["iss"] == SERVICE_ACCOUNT_ID + assert decoded["aud"] == "https://iam.api.cloud.yandex.net/iam/v1/tokens" + assert abs(decoded["iat"] - time.time()) <= 60 + assert abs(decoded["exp"] - time.time()) <= 3600 + + response = iam_token_service_pb2.CreateIamTokenResponse(iam_token="test_token") + response.expires_at.seconds = int(time.time() + 42) + return response + + +class IamTokenServiceTestServer(object): + def __init__(self): + self.server = grpc.server(concurrent.futures.ThreadPoolExecutor(max_workers=2)) + iam_token_service_pb2_grpc.add_IamTokenServiceServicer_to_server(IamTokenServiceForTest(), self.server) + self.server.add_insecure_port(self.get_endpoint()) + self.server.start() + + def stop(self): + self.server.stop(1) + self.server.wait_for_termination() + + def get_endpoint(self): + return "localhost:54321" + + +class TestServiceAccountCredentials(ydb.iam.ServiceAccountCredentials): + def _channel_factory(self): + return grpc.insecure_channel(self._iam_endpoint) + + def get_expire_time(self): + return self._expires_in - time.time() + + +class TestNebiusServiceAccountCredentials(ydb.iam.NebiusServiceAccountCredentials): + def get_expire_time(self): + return self._expires_in - time.time() + + +class NebiusTokenServiceHandler(http.server.BaseHTTPRequestHandler): + def do_POST(self): + assert self.headers["Content-Type"] == "application/x-www-form-urlencoded" + assert self.path == "/token/exchange" + content_length = int(self.headers["Content-Length"]) + post_data = self.rfile.read(content_length).decode("utf8") + print("NebiusTokenServiceHandler.POST data: {}".format(post_data)) + parsed_request = urllib.parse.parse_qs(str(post_data)) + assert len(parsed_request["grant_type"]) == 1 + assert parsed_request["grant_type"][0] == "urn:ietf:params:oauth:grant-type:token-exchange" + + assert len(parsed_request["requested_token_type"]) == 1 + assert parsed_request["requested_token_type"][0] == "urn:ietf:params:oauth:token-type:access_token" + + assert len(parsed_request["subject_token_type"]) == 1 + assert parsed_request["subject_token_type"][0] == "urn:ietf:params:oauth:token-type:jwt" + + assert len(parsed_request["subject_token"]) == 1 + jwt_token = parsed_request["subject_token"][0] + decoded = jwt.decode( + jwt_token, key=PUBLIC_KEY, algorithms=["RS256"], audience="token-service.iam.new.nebiuscloud.net" + ) + assert decoded["iss"] == SERVICE_ACCOUNT_ID + assert decoded["sub"] == SERVICE_ACCOUNT_ID + assert decoded["aud"] == "token-service.iam.new.nebiuscloud.net" + assert abs(decoded["iat"] - time.time()) <= 60 + assert abs(decoded["exp"] - time.time()) <= 3600 + + response = { + "access_token": "test_nebius_token", + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "token_type": "Bearer", + "expires_in": 42, + } + + self.send_response(200) + self.send_header("Content-type", "application/json") + self.end_headers() + self.wfile.write(json.dumps(response).encode("utf8")) + + +class NebiusTokenServiceForTest(http.server.HTTPServer): + def __init__(self): + http.server.HTTPServer.__init__(self, ("localhost", 54322), NebiusTokenServiceHandler) + + def endpoint(self): + return "http://localhost:54322/token/exchange" + + +def test_yandex_service_account_credentials(): + server = IamTokenServiceTestServer() + credentials = TestServiceAccountCredentials(SERVICE_ACCOUNT_ID, ACCESS_KEY_ID, PRIVATE_KEY, server.get_endpoint()) + t = credentials.get_auth_token() + assert t == "test_token" + assert credentials.get_expire_time() <= 42 + server.stop() + + +def test_nebius_service_account_credentials(): + server = NebiusTokenServiceForTest() + + def serve(s): + s.handle_request() + + serve_thread = threading.Thread(target=serve, args=(server,)) + serve_thread.start() + + credentials = TestNebiusServiceAccountCredentials(SERVICE_ACCOUNT_ID, ACCESS_KEY_ID, PRIVATE_KEY, server.endpoint()) + t = credentials.get_auth_token() + assert t == "test_nebius_token" + assert credentials.get_expire_time() <= 42 + + serve_thread.join() diff --git a/tests/table/test_tx.py b/tests/table/test_tx.py index 750a13bf..067ba9bd 100644 --- a/tests/table/test_tx.py +++ b/tests/table/test_tx.py @@ -38,17 +38,6 @@ def test_tx_begin(driver_sync, database): tx.rollback() -def test_credentials(): - credentials = ydb.iam.MetadataUrlCredentials() - raised = False - try: - credentials.auth_metadata() - except Exception: - raised = True - - assert raised - - def test_tx_snapshot_ro(driver_sync, database): session = driver_sync.table_client.session().create() description = ( diff --git a/ydb/aio/iam.py b/ydb/aio/iam.py index eab8faff..40622f8a 100644 --- a/ydb/aio/iam.py +++ b/ydb/aio/iam.py @@ -5,15 +5,19 @@ import logging from ydb.iam import auth from .credentials import AbstractExpiringTokenCredentials +from ydb import issues logger = logging.getLogger(__name__) try: - from yandex.cloud.iam.v1 import iam_token_service_pb2_grpc - from yandex.cloud.iam.v1 import iam_token_service_pb2 import jwt except ImportError: jwt = None + +try: + from yandex.cloud.iam.v1 import iam_token_service_pb2_grpc + from yandex.cloud.iam.v1 import iam_token_service_pb2 +except ImportError: iam_token_service_pb2_grpc = None iam_token_service_pb2 = None @@ -55,6 +59,51 @@ async def _make_token_request(self): IamTokenCredentials = TokenServiceCredentials +class OAuth2JwtTokenExchangeCredentials(AbstractExpiringTokenCredentials, auth.BaseJWTCredentials): + def __init__( + self, + token_exchange_url, + account_id, + access_key_id, + private_key, + algorithm, + token_service_url, + subject=None, + ): + super(OAuth2JwtTokenExchangeCredentials, self).__init__() + auth.BaseJWTCredentials.__init__( + self, account_id, access_key_id, private_key, algorithm, token_service_url, subject + ) + assert aiohttp is not None, "Install aiohttp library to use OAuth 2.0 token exchange credentials provider" + self._token_exchange_url = token_exchange_url + + async def _make_token_request(self): + params = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "subject_token": self._get_jwt(), + "subject_token_type": "urn:ietf:params:oauth:token-type:jwt", + } + headers = {"Content-Type": "application/x-www-form-urlencoded"} + + timeout = aiohttp.ClientTimeout(total=2) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(self._token_exchange_url, data=params, headers=headers) as response: + if response.status == 403: + raise issues.Unauthenticated(await response.text()) + if response.status >= 500: + raise issues.Unavailable(await response.text()) + if response.status >= 400: + raise issues.BadRequest(await response.text()) + if response.status != 200: + raise issues.Error(await response.text()) + + response_json = await response.json() + access_token = response_json["access_token"] + expires_in = response_json["expires_in"] + return {"access_token": access_token, "expires_in": expires_in} + + class JWTIamCredentials(TokenServiceCredentials, auth.BaseJWTCredentials): def __init__( self, @@ -65,16 +114,39 @@ def __init__( iam_channel_credentials=None, ): TokenServiceCredentials.__init__(self, iam_endpoint, iam_channel_credentials) - auth.BaseJWTCredentials.__init__(self, account_id, access_key_id, private_key) + auth.BaseJWTCredentials.__init__( + self, + account_id, + access_key_id, + private_key, + auth.YANDEX_CLOUD_JWT_ALGORITHM, + auth.YANDEX_CLOUD_IAM_TOKEN_SERVICE_URL, + ) def _get_token_request(self): - return iam_token_service_pb2.CreateIamTokenRequest( - jwt=auth.get_jwt( - self._account_id, - self._access_key_id, - self._private_key, - self._jwt_expiration_timeout, - ) + return iam_token_service_pb2.CreateIamTokenRequest(jwt=self._get_jwt()) + + +class NebiusJWTIamCredentials(OAuth2JwtTokenExchangeCredentials): + def __init__( + self, + account_id, + access_key_id, + private_key, + token_exchange_url=None, + ): + url = token_exchange_url + if url is None: + url = auth.NEBIUS_CLOUD_IAM_TOKEN_EXCHANGE_URL + OAuth2JwtTokenExchangeCredentials.__init__( + self, + url, + account_id, + access_key_id, + private_key, + auth.NEBIUS_CLOUD_JWT_ALGORITHM, + auth.NEBIUS_CLOUD_IAM_TOKEN_SERVICE_AUDIENCE, + account_id, ) @@ -130,3 +202,20 @@ def __init__( iam_endpoint, iam_channel_credentials, ) + + +class NebiusServiceAccountCredentials(NebiusJWTIamCredentials): + def __init__( + self, + service_account_id, + access_key_id, + private_key, + iam_endpoint=None, + iam_channel_credentials=None, + ): + super(NebiusServiceAccountCredentials, self).__init__( + service_account_id, + access_key_id, + private_key, + iam_endpoint, + ) diff --git a/ydb/driver.py b/ydb/driver.py index 89109b9b..16bba151 100644 --- a/ydb/driver.py +++ b/ydb/driver.py @@ -38,6 +38,13 @@ def credentials_from_env_variables(tracer=None): return ydb.iam.ServiceAccountCredentials.from_file(service_account_key_file) + nebius_service_account_key_file = os.getenv("YDB_NEBIUS_SERVICE_ACCOUNT_KEY_FILE_CREDENTIALS") + if nebius_service_account_key_file is not None: + ctx.trace({"credentials.nebius_service_account_key_file": True}) + import ydb.iam + + return ydb.iam.NebiusServiceAccountCredentials.from_file(nebius_service_account_key_file) + anonymous_credetials = os.getenv("YDB_ANONYMOUS_CREDENTIALS", "0") == "1" if anonymous_credetials: ctx.trace({"credentials.anonymous": True}) diff --git a/ydb/iam/__init__.py b/ydb/iam/__init__.py index 7167efe1..cf835769 100644 --- a/ydb/iam/__init__.py +++ b/ydb/iam/__init__.py @@ -1,3 +1,4 @@ # -*- coding: utf-8 -*- from .auth import ServiceAccountCredentials # noqa +from .auth import NebiusServiceAccountCredentials # noqa from .auth import MetadataUrlCredentials # noqa diff --git a/ydb/iam/auth.py b/ydb/iam/auth.py index 82e7c9f6..852c0c28 100644 --- a/ydb/iam/auth.py +++ b/ydb/iam/auth.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -from ydb import credentials, tracing +from ydb import credentials, tracing, issues import grpc import time import abc @@ -8,11 +8,14 @@ import os try: - from yandex.cloud.iam.v1 import iam_token_service_pb2_grpc - from yandex.cloud.iam.v1 import iam_token_service_pb2 import jwt except ImportError: jwt = None + +try: + from yandex.cloud.iam.v1 import iam_token_service_pb2_grpc + from yandex.cloud.iam.v1 import iam_token_service_pb2 +except ImportError: iam_token_service_pb2_grpc = None iam_token_service_pb2 = None @@ -23,22 +26,32 @@ DEFAULT_METADATA_URL = "http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/token" +YANDEX_CLOUD_IAM_TOKEN_SERVICE_URL = "https://iam.api.cloud.yandex.net/iam/v1/tokens" +NEBIUS_CLOUD_IAM_TOKEN_SERVICE_AUDIENCE = "token-service.iam.new.nebiuscloud.net" +NEBIUS_CLOUD_IAM_TOKEN_EXCHANGE_URL = "https://auth.new.nebiuscloud.net/oauth2/token/exchange" +YANDEX_CLOUD_JWT_ALGORITHM = "PS256" +NEBIUS_CLOUD_JWT_ALGORITHM = "RS256" -def get_jwt(account_id, access_key_id, private_key, jwt_expiration_timeout): + +def get_jwt(account_id, access_key_id, private_key, jwt_expiration_timeout, algorithm, token_service_url, subject=None): + assert jwt is not None, "Install pyjwt library to use jwt tokens" now = time.time() now_utc = datetime.utcfromtimestamp(now) exp_utc = datetime.utcfromtimestamp(now + jwt_expiration_timeout) + payload = { + "iss": account_id, + "aud": token_service_url, + "iat": now_utc, + "exp": exp_utc, + } + if subject is not None: + payload["sub"] = subject return jwt.encode( key=private_key, - algorithm="PS256", - headers={"typ": "JWT", "alg": "PS256", "kid": access_key_id}, - payload={ - "iss": account_id, - "aud": "https://iam.api.cloud.yandex.net/iam/v1/tokens", - "iat": now_utc, - "exp": exp_utc, - }, + algorithm=algorithm, + headers={"typ": "JWT", "alg": algorithm, "kid": access_key_id}, + payload=payload, ) @@ -73,12 +86,15 @@ def _make_token_request(self): class BaseJWTCredentials(abc.ABC): - def __init__(self, account_id, access_key_id, private_key): + def __init__(self, account_id, access_key_id, private_key, algorithm, token_service_url, subject=None): self._account_id = account_id self._jwt_expiration_timeout = 60.0 * 60 self._token_expiration_timeout = 120 self._access_key_id = access_key_id self._private_key = private_key + self._algorithm = algorithm + self._token_service_url = token_service_url + self._subject = subject def set_token_expiration_timeout(self, value): self._token_expiration_timeout = value @@ -99,6 +115,64 @@ def from_file(cls, key_file, iam_endpoint=None, iam_channel_credentials=None): iam_channel_credentials=iam_channel_credentials, ) + def _get_jwt(self): + return get_jwt( + self._account_id, + self._access_key_id, + self._private_key, + self._jwt_expiration_timeout, + self._algorithm, + self._token_service_url, + self._subject, + ) + + +class OAuth2JwtTokenExchangeCredentials(credentials.AbstractExpiringTokenCredentials, BaseJWTCredentials): + def __init__( + self, + token_exchange_url, + account_id, + access_key_id, + private_key, + algorithm, + token_service_url, + subject=None, + tracer=None, + ): + BaseJWTCredentials.__init__(self, account_id, access_key_id, private_key, algorithm, token_service_url, subject) + super(OAuth2JwtTokenExchangeCredentials, self).__init__(tracer) + assert requests is not None, "Install requests library to use OAuth 2.0 token exchange credentials provider" + self._token_exchange_url = token_exchange_url + + def _process_response_status_code(self, response): + if response.status_code == 403: + raise issues.Unauthenticated(response.content) + if response.status_code >= 500: + raise issues.Unavailable(response.content) + if response.status_code >= 400: + raise issues.BadRequest(response.content) + if response.status_code != 200: + raise issues.Error(response.content) + + def _process_response(self, response): + self._process_response_status_code(response) + response_json = json.loads(response.content) + access_token = response_json["access_token"] + expires_in = response_json["expires_in"] + return {"access_token": access_token, "expires_in": expires_in} + + @tracing.with_trace() + def _make_token_request(self): + params = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "subject_token": self._get_jwt(), + "subject_token_type": "urn:ietf:params:oauth:token-type:jwt", + } + headers = {"Content-Type": "application/x-www-form-urlencoded"} + response = requests.post(self._token_exchange_url, data=params, headers=headers) + return self._process_response(response) + class JWTIamCredentials(TokenServiceCredentials, BaseJWTCredentials): def __init__( @@ -110,16 +184,34 @@ def __init__( iam_channel_credentials=None, ): TokenServiceCredentials.__init__(self, iam_endpoint, iam_channel_credentials) - BaseJWTCredentials.__init__(self, account_id, access_key_id, private_key) + BaseJWTCredentials.__init__( + self, account_id, access_key_id, private_key, YANDEX_CLOUD_JWT_ALGORITHM, YANDEX_CLOUD_IAM_TOKEN_SERVICE_URL + ) def _get_token_request(self): - return self._iam_token_service_pb2.CreateIamTokenRequest( - jwt=get_jwt( - self._account_id, - self._access_key_id, - self._private_key, - self._jwt_expiration_timeout, - ) + return self._iam_token_service_pb2.CreateIamTokenRequest(jwt=self._get_jwt()) + + +class NebiusJWTIamCredentials(OAuth2JwtTokenExchangeCredentials): + def __init__( + self, + account_id, + access_key_id, + private_key, + token_exchange_url=None, + ): + url = token_exchange_url + if url is None: + url = NEBIUS_CLOUD_IAM_TOKEN_EXCHANGE_URL + OAuth2JwtTokenExchangeCredentials.__init__( + self, + url, + account_id, + access_key_id, + private_key, + NEBIUS_CLOUD_JWT_ALGORITHM, + NEBIUS_CLOUD_IAM_TOKEN_SERVICE_AUDIENCE, + account_id, ) @@ -176,3 +268,20 @@ def __init__( iam_endpoint, iam_channel_credentials, ) + + +class NebiusServiceAccountCredentials(NebiusJWTIamCredentials): + def __init__( + self, + service_account_id, + access_key_id, + private_key, + iam_endpoint=None, + iam_channel_credentials=None, + ): + super(NebiusServiceAccountCredentials, self).__init__( + service_account_id, + access_key_id, + private_key, + iam_endpoint, + )