Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ grpcio>=1.42.0
packaging
protobuf>=3.13.0,<5.0.0
aiohttp<4
pyjwt==2.8.0
1 change: 1 addition & 0 deletions test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,5 @@ pylint-protobuf
cython
freezegun==1.2.2
pytest-cov
yandexcloud
-e .
58 changes: 58 additions & 0 deletions tests/aio/test_credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import pytest
import time
import grpc
import threading

import tests.auth.test_credentials
import ydb.aio.iam


class TestServiceAccountCredentials(ydb.aio.iam.ServiceAccountCredentials):
def _channel_factory(self):
return grpc.aio.insecure_channel(self._iam_endpoint)

def get_expire_time(self):
return self._expires_in - time.time()


class TestNebiusServiceAccountCredentials(ydb.aio.iam.NebiusServiceAccountCredentials):
def get_expire_time(self):
return self._expires_in - time.time()


@pytest.mark.asyncio
async def test_yandex_service_account_credentials():
server = tests.auth.test_credentials.IamTokenServiceTestServer()
credentials = TestServiceAccountCredentials(
tests.auth.test_credentials.SERVICE_ACCOUNT_ID,
tests.auth.test_credentials.ACCESS_KEY_ID,
tests.auth.test_credentials.PRIVATE_KEY,
server.get_endpoint(),
)
t = (await credentials.auth_metadata())[0][1]
assert t == "test_token"
assert credentials.get_expire_time() <= 42
server.stop()


@pytest.mark.asyncio
async def test_nebius_service_account_credentials():
server = tests.auth.test_credentials.NebiusTokenServiceForTest()

def serve(s):
s.handle_request()

serve_thread = threading.Thread(target=serve, args=(server,))
serve_thread.start()

credentials = TestNebiusServiceAccountCredentials(
tests.auth.test_credentials.SERVICE_ACCOUNT_ID,
tests.auth.test_credentials.ACCESS_KEY_ID,
tests.auth.test_credentials.PRIVATE_KEY,
server.endpoint(),
)
t = (await credentials.auth_metadata())[0][1]
assert t == "test_nebius_token"
assert credentials.get_expire_time() <= 42

serve_thread.join()
Empty file added tests/auth/__init__.py
Empty file.
149 changes: 149 additions & 0 deletions tests/auth/test_credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import jwt
import concurrent.futures
import grpc
import time
import http.server
import urllib
import threading
import json

import ydb.iam

from yandex.cloud.iam.v1 import iam_token_service_pb2_grpc
from yandex.cloud.iam.v1 import iam_token_service_pb2

SERVICE_ACCOUNT_ID = "sa_id"
ACCESS_KEY_ID = "key_id"
PRIVATE_KEY = "-----BEGIN PRIVATE KEY-----\nMIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC75/JS3rMcLJxv\nFgpOzF5+2gH+Yig3RE2MTl9uwC0BZKAv6foYr7xywQyWIK+W1cBhz8R4LfFmZo2j\nM0aCvdRmNBdW0EDSTnHLxCsFhoQWLVq+bI5f5jzkcoiioUtaEpADPqwgVULVtN/n\nnPJiZ6/dU30C3jmR6+LUgEntUtWt3eq3xQIn5lG3zC1klBY/HxtfH5Hu8xBvwRQT\nJnh3UpPLj8XwSmriDgdrhR7o6umWyVuGrMKlLHmeivlfzjYtfzO1MOIMG8t2/zxG\nR+xb4Vwks73sH1KruH/0/JMXU97npwpe+Um+uXhpldPygGErEia7abyZB2gMpXqr\nWYKMo02NAgMBAAECggEAO0BpC5OYw/4XN/optu4/r91bupTGHKNHlsIR2rDzoBhU\nYLd1evpTQJY6O07EP5pYZx9mUwUdtU4KRJeDGO/1/WJYp7HUdtxwirHpZP0lQn77\nuccuX/QQaHLrPekBgz4ONk+5ZBqukAfQgM7fKYOLk41jgpeDbM2Ggb6QUSsJISEp\nzrwpI/nNT/wn+Hvx4DxrzWU6wF+P8kl77UwPYlTA7GsT+T7eKGVH8xsxmK8pt6lg\nsvlBA5XosWBWUCGLgcBkAY5e4ZWbkdd183o+oMo78id6C+PQPE66PLDtHWfpRRmN\nm6XC03x6NVhnfvfozoWnmS4+e4qj4F/emCHvn0GMywKBgQDLXlj7YPFVXxZpUvg/\nrheVcCTGbNmQJ+4cZXx87huqwqKgkmtOyeWsRc7zYInYgraDrtCuDBCfP//ZzOh0\nLxepYLTPk5eNn/GT+VVrqsy35Ccr60g7Lp/bzb1WxyhcLbo0KX7/6jl0lP+VKtdv\nmto+4mbSBXSM1Y5BVVoVgJ3T/wKBgQDsiSvPRzVi5TTj13x67PFymTMx3HCe2WzH\nJUyepCmVhTm482zW95pv6raDr5CTO6OYpHtc5sTTRhVYEZoEYFTM9Vw8faBtluWG\nBjkRh4cIpoIARMn74YZKj0C/0vdX7SHdyBOU3bgRPHg08Hwu3xReqT1kEPSI/B2V\n4pe5fVrucwKBgQCNFgUxUA3dJjyMES18MDDYUZaRug4tfiYouRdmLGIxUxozv6CG\nZnbZzwxFt+GpvPUV4f+P33rgoCvFU+yoPctyjE6j+0aW0DFucPmb2kBwCu5J/856\nkFwCx3blbwFHAco+SdN7g2kcwgmV2MTg/lMOcU7XwUUcN0Obe7UlWbckzQKBgQDQ\nnXaXHL24GGFaZe4y2JFmujmNy1dEsoye44W9ERpf9h1fwsoGmmCKPp90az5+rIXw\nFXl8CUgk8lXW08db/r4r+ma8Lyx0GzcZyplAnaB5/6j+pazjSxfO4KOBy4Y89Tb+\nTP0AOcCi6ws13bgY+sUTa/5qKA4UVw+c5zlb7nRpgwKBgGXAXhenFw1666482iiN\ncHSgwc4ZHa1oL6aNJR1XWH+aboBSwR+feKHUPeT4jHgzRGo/aCNHD2FE5I8eBv33\nof1kWYjAO0YdzeKrW0rTwfvt9gGg+CS397aWu4cy+mTI+MNfBgeDAIVBeJOJXLlX\nhL8bFAuNNVrCOp79TNnNIsh7\n-----END PRIVATE KEY-----\n" # noqa: E501
PUBLIC_KEY = "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAu+fyUt6zHCycbxYKTsxe\nftoB/mIoN0RNjE5fbsAtAWSgL+n6GK+8csEMliCvltXAYc/EeC3xZmaNozNGgr3U\nZjQXVtBA0k5xy8QrBYaEFi1avmyOX+Y85HKIoqFLWhKQAz6sIFVC1bTf55zyYmev\n3VN9At45kevi1IBJ7VLVrd3qt8UCJ+ZRt8wtZJQWPx8bXx+R7vMQb8EUEyZ4d1KT\ny4/F8Epq4g4Ha4Ue6OrplslbhqzCpSx5nor5X842LX8ztTDiDBvLdv88RkfsW+Fc\nJLO97B9Sq7h/9PyTF1Pe56cKXvlJvrl4aZXT8oBhKxImu2m8mQdoDKV6q1mCjKNN\njQIDAQAB\n-----END PUBLIC KEY-----\n" # noqa: E501


def test_metadata_credentials():
credentials = ydb.iam.MetadataUrlCredentials()
raised = False
try:
credentials.auth_metadata()
except Exception:
raised = True

assert raised


class IamTokenServiceForTest(iam_token_service_pb2_grpc.IamTokenServiceServicer):
def Create(self, request, context):
print("IAM token service request: {}".format(request))
# Validate jwt:
decoded = jwt.decode(
request.jwt, key=PUBLIC_KEY, algorithms=["PS256"], audience="https://iam.api.cloud.yandex.net/iam/v1/tokens"
)
assert decoded["iss"] == SERVICE_ACCOUNT_ID
assert decoded["aud"] == "https://iam.api.cloud.yandex.net/iam/v1/tokens"
assert abs(decoded["iat"] - time.time()) <= 60
assert abs(decoded["exp"] - time.time()) <= 3600

response = iam_token_service_pb2.CreateIamTokenResponse(iam_token="test_token")
response.expires_at.seconds = int(time.time() + 42)
return response


class IamTokenServiceTestServer(object):
def __init__(self):
self.server = grpc.server(concurrent.futures.ThreadPoolExecutor(max_workers=2))
iam_token_service_pb2_grpc.add_IamTokenServiceServicer_to_server(IamTokenServiceForTest(), self.server)
self.server.add_insecure_port(self.get_endpoint())
self.server.start()

def stop(self):
self.server.stop(1)
self.server.wait_for_termination()

def get_endpoint(self):
return "localhost:54321"


class TestServiceAccountCredentials(ydb.iam.ServiceAccountCredentials):
def _channel_factory(self):
return grpc.insecure_channel(self._iam_endpoint)

def get_expire_time(self):
return self._expires_in - time.time()


class TestNebiusServiceAccountCredentials(ydb.iam.NebiusServiceAccountCredentials):
def get_expire_time(self):
return self._expires_in - time.time()


class NebiusTokenServiceHandler(http.server.BaseHTTPRequestHandler):
def do_POST(self):
assert self.headers["Content-Type"] == "application/x-www-form-urlencoded"
assert self.path == "/token/exchange"
content_length = int(self.headers["Content-Length"])
post_data = self.rfile.read(content_length).decode("utf8")
print("NebiusTokenServiceHandler.POST data: {}".format(post_data))
parsed_request = urllib.parse.parse_qs(str(post_data))
assert len(parsed_request["grant_type"]) == 1
assert parsed_request["grant_type"][0] == "urn:ietf:params:oauth:grant-type:token-exchange"

assert len(parsed_request["requested_token_type"]) == 1
assert parsed_request["requested_token_type"][0] == "urn:ietf:params:oauth:token-type:access_token"

assert len(parsed_request["subject_token_type"]) == 1
assert parsed_request["subject_token_type"][0] == "urn:ietf:params:oauth:token-type:jwt"

assert len(parsed_request["subject_token"]) == 1
jwt_token = parsed_request["subject_token"][0]
decoded = jwt.decode(
jwt_token, key=PUBLIC_KEY, algorithms=["RS256"], audience="token-service.iam.new.nebiuscloud.net"
)
assert decoded["iss"] == SERVICE_ACCOUNT_ID
assert decoded["sub"] == SERVICE_ACCOUNT_ID
assert decoded["aud"] == "token-service.iam.new.nebiuscloud.net"
assert abs(decoded["iat"] - time.time()) <= 60
assert abs(decoded["exp"] - time.time()) <= 3600

response = {
"access_token": "test_nebius_token",
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
"token_type": "Bearer",
"expires_in": 42,
}

self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()
self.wfile.write(json.dumps(response).encode("utf8"))


class NebiusTokenServiceForTest(http.server.HTTPServer):
def __init__(self):
http.server.HTTPServer.__init__(self, ("localhost", 54322), NebiusTokenServiceHandler)

def endpoint(self):
return "http://localhost:54322/token/exchange"


def test_yandex_service_account_credentials():
server = IamTokenServiceTestServer()
credentials = TestServiceAccountCredentials(SERVICE_ACCOUNT_ID, ACCESS_KEY_ID, PRIVATE_KEY, server.get_endpoint())
t = credentials.get_auth_token()
assert t == "test_token"
assert credentials.get_expire_time() <= 42
server.stop()


def test_nebius_service_account_credentials():
server = NebiusTokenServiceForTest()

def serve(s):
s.handle_request()

serve_thread = threading.Thread(target=serve, args=(server,))
serve_thread.start()

credentials = TestNebiusServiceAccountCredentials(SERVICE_ACCOUNT_ID, ACCESS_KEY_ID, PRIVATE_KEY, server.endpoint())
t = credentials.get_auth_token()
assert t == "test_nebius_token"
assert credentials.get_expire_time() <= 42

serve_thread.join()
11 changes: 0 additions & 11 deletions tests/table/test_tx.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,6 @@ def test_tx_begin(driver_sync, database):
tx.rollback()


def test_credentials():
credentials = ydb.iam.MetadataUrlCredentials()
raised = False
try:
credentials.auth_metadata()
except Exception:
raised = True

assert raised


def test_tx_snapshot_ro(driver_sync, database):
session = driver_sync.table_client.session().create()
description = (
Expand Down
109 changes: 99 additions & 10 deletions ydb/aio/iam.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,19 @@
import logging
from ydb.iam import auth
from .credentials import AbstractExpiringTokenCredentials
from ydb import issues

logger = logging.getLogger(__name__)

try:
from yandex.cloud.iam.v1 import iam_token_service_pb2_grpc
from yandex.cloud.iam.v1 import iam_token_service_pb2
import jwt
except ImportError:
jwt = None

try:
from yandex.cloud.iam.v1 import iam_token_service_pb2_grpc
from yandex.cloud.iam.v1 import iam_token_service_pb2
except ImportError:
iam_token_service_pb2_grpc = None
iam_token_service_pb2 = None

Expand Down Expand Up @@ -55,6 +59,51 @@ async def _make_token_request(self):
IamTokenCredentials = TokenServiceCredentials


class OAuth2JwtTokenExchangeCredentials(AbstractExpiringTokenCredentials, auth.BaseJWTCredentials):
def __init__(
self,
token_exchange_url,
account_id,
access_key_id,
private_key,
algorithm,
token_service_url,
subject=None,
):
super(OAuth2JwtTokenExchangeCredentials, self).__init__()
auth.BaseJWTCredentials.__init__(
self, account_id, access_key_id, private_key, algorithm, token_service_url, subject
)
assert aiohttp is not None, "Install aiohttp library to use OAuth 2.0 token exchange credentials provider"
self._token_exchange_url = token_exchange_url

async def _make_token_request(self):
params = {
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
"requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
"subject_token": self._get_jwt(),
"subject_token_type": "urn:ietf:params:oauth:token-type:jwt",
}
headers = {"Content-Type": "application/x-www-form-urlencoded"}

timeout = aiohttp.ClientTimeout(total=2)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.post(self._token_exchange_url, data=params, headers=headers) as response:
if response.status == 403:
raise issues.Unauthenticated(await response.text())
if response.status >= 500:
raise issues.Unavailable(await response.text())
if response.status >= 400:
raise issues.BadRequest(await response.text())
if response.status != 200:
raise issues.Error(await response.text())

response_json = await response.json()
access_token = response_json["access_token"]
expires_in = response_json["expires_in"]
return {"access_token": access_token, "expires_in": expires_in}


class JWTIamCredentials(TokenServiceCredentials, auth.BaseJWTCredentials):
def __init__(
self,
Expand All @@ -65,16 +114,39 @@ def __init__(
iam_channel_credentials=None,
):
TokenServiceCredentials.__init__(self, iam_endpoint, iam_channel_credentials)
auth.BaseJWTCredentials.__init__(self, account_id, access_key_id, private_key)
auth.BaseJWTCredentials.__init__(
self,
account_id,
access_key_id,
private_key,
auth.YANDEX_CLOUD_JWT_ALGORITHM,
auth.YANDEX_CLOUD_IAM_TOKEN_SERVICE_URL,
)

def _get_token_request(self):
return iam_token_service_pb2.CreateIamTokenRequest(
jwt=auth.get_jwt(
self._account_id,
self._access_key_id,
self._private_key,
self._jwt_expiration_timeout,
)
return iam_token_service_pb2.CreateIamTokenRequest(jwt=self._get_jwt())


class NebiusJWTIamCredentials(OAuth2JwtTokenExchangeCredentials):
def __init__(
self,
account_id,
access_key_id,
private_key,
token_exchange_url=None,
):
url = token_exchange_url
if url is None:
url = auth.NEBIUS_CLOUD_IAM_TOKEN_EXCHANGE_URL
OAuth2JwtTokenExchangeCredentials.__init__(
self,
url,
account_id,
access_key_id,
private_key,
auth.NEBIUS_CLOUD_JWT_ALGORITHM,
auth.NEBIUS_CLOUD_IAM_TOKEN_SERVICE_AUDIENCE,
account_id,
)


Expand Down Expand Up @@ -130,3 +202,20 @@ def __init__(
iam_endpoint,
iam_channel_credentials,
)


class NebiusServiceAccountCredentials(NebiusJWTIamCredentials):
def __init__(
self,
service_account_id,
access_key_id,
private_key,
iam_endpoint=None,
iam_channel_credentials=None,
):
super(NebiusServiceAccountCredentials, self).__init__(
service_account_id,
access_key_id,
private_key,
iam_endpoint,
)
7 changes: 7 additions & 0 deletions ydb/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ def credentials_from_env_variables(tracer=None):

return ydb.iam.ServiceAccountCredentials.from_file(service_account_key_file)

nebius_service_account_key_file = os.getenv("YDB_NEBIUS_SERVICE_ACCOUNT_KEY_FILE_CREDENTIALS")
if nebius_service_account_key_file is not None:
ctx.trace({"credentials.nebius_service_account_key_file": True})
import ydb.iam

return ydb.iam.NebiusServiceAccountCredentials.from_file(nebius_service_account_key_file)

anonymous_credetials = os.getenv("YDB_ANONYMOUS_CREDENTIALS", "0") == "1"
if anonymous_credetials:
ctx.trace({"credentials.anonymous": True})
Expand Down
1 change: 1 addition & 0 deletions ydb/iam/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
from .auth import ServiceAccountCredentials # noqa
from .auth import NebiusServiceAccountCredentials # noqa
from .auth import MetadataUrlCredentials # noqa
Loading