Skip to content

Commit

Permalink
GML-1660 add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Lu Zhou authored and Lu Zhou committed Jun 7, 2024
1 parent 3fef8fc commit 1ad4209
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 73 deletions.
33 changes: 10 additions & 23 deletions pyTigerGraph/pyTigerGraphBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,35 +219,22 @@ def _set_auth_header(self):

def _verify_jwt_token_support(self):
try:
# Check JWT support for RestPP server
logger.debug("Attempting to verify JWT token support with getVer() on RestPP server.")
logger.debug(f"Using auth header: {self.authHeader}")
version = self.getVer()
logger.info(f"Database version: {version}")
except requests.exceptions.HTTPError as e:
if e.response.status_code == 403:
logger.error(f"Unauthorized error: {e}. The JWT token might be invalid or expired.")
else:
logger.error(f"HTTP error occurred: {e}")
raise
except Exception as e:
logger.error(f"Error occurred: {e}. The DB version using doesn't support JWT token for RestPP.")
logger.error("Please switch to API token or username/password.")
raise RuntimeError("The DB version using doesn't support JWT token for RestPP. Please switch to API token or username/password.") from e

# Check JWT support for GSQL server
try:
logger.debug(f"Attempting to get schema with URL: {self.gsUrl + '/gsqlserver/gsql/simpleauth'}")
logger.debug(f"Using auth header: {self.authHeader}")
self._get(self.gsUrl + "/gsqlserver/gsql/simpleauth", authMode="token", resKey=None)
# Check JWT support for GSQL server
logger.debug(f"Attempting to get auth info with URL: {self.gsUrl + '/gsqlserver/gsql/simpleauth'}")
self._get(f"{self.gsUrl}/gsqlserver/gsql/simpleauth", authMode="token", resKey=None)
except requests.exceptions.HTTPError as e:
if e.response.status_code == 401:
logger.error(f"Unauthorized error: {e}. The JWT token might be invalid or expired.")
else:
logger.error("The DB version using doesn't support JWT token for GSQL. Please switch to API token or username/password.")
raise
logger.error(f"HTTP error: {e}. The JWT token might be invalid or expired.")
raise RuntimeError(f"HTTP error: {e}, The JWT token might be invalid or expired.") from e
except Exception as e:
logger.error(f"Error occurred in _get request: {e}. The DB version using doesn't support JWT token for GSQL.")
logger.error("Please switch to API token or username/password.")
raise
message = "The DB version using doesn't support JWT token. Please switch to API token or username/password."
logger.error(f"Error occurred: {e}. {message}")
raise RuntimeError(message) from e

def _locals(self, _locals: dict) -> str:
del _locals["self"]
Expand Down
94 changes: 44 additions & 50 deletions tests/test_jwtAuth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import requests
import json
from requests.auth import HTTPBasicAuth

import unittest
Expand All @@ -8,69 +9,43 @@
from tests.pyTigerGraphUnitTest import make_connection

class TestJWTTokenAuth(unittest.TestCase):
# @classmethod
# def setUpClass(cls):
# cls.conn = make_connection()

@classmethod
def setUpClass(cls):
cls.conn = MagicMock()
cls.conn.host = "http://localhost"
cls.conn.gsPort = 14240
cls.conn.username = "tigergraph"
cls.conn.password = "tigergraph"

def test_jwtauth(self):
# version = self.conn.getVer()

# if "4.1" in version:
# self.test_jwtauth_4_1()
# elif "3.9" in version:
# self.test_jwtauth_3_9()
# else:
# pass ## todo: don't have a good way to test on 3.10.0, since there is no endpoint to request and configure jwt token
with patch.object(self.conn, 'getVer', return_value="4.1"):
self.test_jwtauth_4_1()

with patch.object(self.conn, 'getVer', return_value="3.9"):
self.test_jwtauth_3_9()

cls.conn = make_connection()

def requestJWTToken(self):
# Define the URL
url = f"{self.conn.host}:{self.conn.gsPort}gsqlserver/requestjwttoken"

url = f"{self.conn.host}:{self.conn.gsPort}/gsqlserver/requestjwttoken"
# Define the data payload
payload = {"lifetime": "1000000000"}

payload = json.dumps({"lifetime": "1000000000"})
# Define the headers for the request
headers = {
'Content-Type': 'application/json'
}
# Make the POST request with basic authentication
response = requests.post(url, json=payload, auth=HTTPBasicAuth(self.conn.username, self.conn.password))
return response.json()["jwt"]
response = requests.post(url, data=payload, headers=headers, auth=(self.conn.username, self.conn.password))
return response.json()['token']

def test_jwtauth(self):
dbversion = self.conn.getVer()
if "3.9" in dbversion:
self.test_jwtauth_3_9()
elif "4.1" in dbversion:
self.test_jwtauth_4_1_success()
self.test_jwtauth_4_1_fail()

@patch('pyTigerGraph.TigerGraphConnection.getVer', return_value="3.9")
def test_jwtauth_3_9(self, mock_getVer):
def test_jwtauth_3_9(self):
with self.assertRaises(RuntimeError) as context:
TigerGraphConnection(
conn = TigerGraphConnection(
host=self.conn.host,
jwtToken="fake.JWT.Token"
)

self.assertIn("The DB version using doesn't support JWT token for RestPP.", str(context.exception))

@patch('pyTigerGraph.TigerGraphConnection.getVer', return_value="4.1")
@patch('requests.post')
@patch('pyTigerGraph.TigerGraphConnection._get')
def test_jwtauth_4_1(self, mock_get, mock_post, mock_getVer):

# Mock the response for requestJWTToken
jwt_resposne = {"jwt": "valid.JWT.Token"}
# mock_post.return_value = Mock(status_code=200, json=lambda: {"jwt": "valid.JWT.Token"})
# Verify the exception message
self.assertIn("HTTP error", str(context.exception))

# Mock the response for _get (RestPP endpoint)
mock_get.return_value = {"privileges": "some_privileges"}

# Test JWT token on 4.1
# jwt_token = self.requestJWTToken()
jwt_token = "valid.JWT.Token"
def test_jwtauth_4_1_success(self):
jwt_token = self.requestJWTToken()

conn = TigerGraphConnection(
host=self.conn.host,
Expand All @@ -85,5 +60,24 @@ def test_jwtauth_4_1(self, mock_get, mock_post, mock_getVer):
res = conn._get(f"http://{self.conn.host}:{self.conn.gsPort}/gsqlserver/gsql/simpleauth", authMode="token", resKey=None)
self.assertIn("privileges", res)

def test_jwtauth_4_1_fail(self):
jwt_token = self.requestJWTToken()

with self.assertRaises(RuntimeError) as context:
conn = TigerGraphConnection(
host=self.conn.host,
jwtToken=jwt_token
)

# restpp on port 9000
conn.getVer()

# gsql on port 14240
conn._get(f"http://{self.conn.host}:{self.conn.gsPort}/gsqlserver/gsql/simpleauth", authMode="token", resKey=None)

# Verify the exception message
self.assertIn("The JWT token might be invalid or expired", str(context.exception))


if __name__ == '__main__':
unittest.main()

0 comments on commit 1ad4209

Please sign in to comment.