diff --git a/pyproject.toml b/pyproject.toml
index 37aad62..20663e0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,7 +29,6 @@ dependencies = [
     "boto3~=1.26.0",
     "pydantic~=1.10.0",
     "httpx~=0.23.0",
-    "orjson~=3.8.0",
     "click~=8.1.0",
     "tenacity==8.2.2",
     "azure-identity==1.13.0",
diff --git a/sync/__init__.py b/sync/__init__.py
index 57e6122..c44a4a9 100644
--- a/sync/__init__.py
+++ b/sync/__init__.py
@@ -1,4 +1,4 @@
 """Library for leveraging the power of Sync"""
-__version__ = "0.5.1"
+__version__ = "0.5.2"
 
 TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"
diff --git a/sync/awsdatabricks.py b/sync/awsdatabricks.py
index b7fcb89..2ad9f5c 100644
--- a/sync/awsdatabricks.py
+++ b/sync/awsdatabricks.py
@@ -1,3 +1,4 @@
+import json
 import logging
 from time import sleep
 from typing import List, Tuple
@@ -5,7 +6,6 @@
 import boto3 as boto
 import botocore
-import orjson
 from botocore.exceptions import ClientError
 
 import sync._databricks
 
@@ -56,6 +56,7 @@
     Response,
 )
 from sync.utils.dbfs import format_dbfs_filepath, write_dbfs_file
+from sync.utils.json import DefaultDateTimeEncoder
 
 __all__ = [
     "get_access_report",
@@ -273,7 +274,7 @@ def _load_aws_cluster_info(cluster: dict) -> Tuple[Response[dict], Response[dict
     cluster_info_file_response = _get_cluster_instances_from_dbfs(cluster_info_file_key)
 
     cluster_info = (
-        orjson.loads(cluster_info_file_response) if cluster_info_file_response else None
+        json.loads(cluster_info_file_response) if cluster_info_file_response else None
     )
 
     # If this cluster does not have the "Sync agent" configured, attempt a best-effort snapshot of the instances that
@@ -409,12 +410,16 @@ def write_file(body: bytes):
         all_timelines = retired_timelines + list(active_timelines_by_id.values())
 
         write_file(
-            orjson.dumps(
-                {
-                    "instances": list(all_inst_by_id.values()),
-                    "instance_timelines": all_timelines,
-                    "volumes": list(recorded_volumes_by_id.values()),
-                }
+            bytes(
+                json.dumps(
+                    {
+                        "instances": list(all_inst_by_id.values()),
+                        "instance_timelines": all_timelines,
+                        "volumes": list(recorded_volumes_by_id.values()),
+                    },
+                    cls=DefaultDateTimeEncoder,
+                ),
+                "utf-8",
             )
         )
     except Exception as e:
diff --git a/sync/awsemr.py b/sync/awsemr.py
index 42de050..a8f0035 100644
--- a/sync/awsemr.py
+++ b/sync/awsemr.py
@@ -4,6 +4,7 @@
 
 import datetime
 import io
+import json
 import logging
 import re
 from copy import deepcopy
@@ -12,7 +13,6 @@
 from uuid import uuid4
 
 import boto3 as boto
-import orjson
 from dateutil.parser import parse as dateparse
 
 from sync import TIME_FORMAT
@@ -28,6 +28,7 @@
     ProjectError,
     Response,
 )
+from sync.utils.json import DateTimeEncoderNaiveUTCDropMicroseconds
 
 logger = logging.getLogger(__name__)
 
@@ -364,7 +365,7 @@ def get_project_cluster_report(  # noqa: C901
     s3.download_fileobj(parsed_project_url.netloc, config_key, config)
     return Response(
         result=(
-            orjson.loads(config.getvalue().decode()),
+            json.loads(config.getvalue().decode()),
             f"s3://{parsed_project_url.netloc}/{log_key}",
         )
     )
@@ -753,10 +754,7 @@ def _upload_object(obj: dict, s3_url: str) -> Response[str]:
     s3 = boto.client("s3")
     s3.upload_fileobj(
         io.BytesIO(
-            orjson.dumps(
-                obj,
-                option=orjson.OPT_UTC_Z | orjson.OPT_OMIT_MICROSECONDS | orjson.OPT_NAIVE_UTC,
-            )
+            bytes(json.dumps(obj, cls=DateTimeEncoderNaiveUTCDropMicroseconds), "utf-8")
         ),
         parsed_url.netloc,
         obj_key,
diff --git a/sync/azuredatabricks.py b/sync/azuredatabricks.py
index ab5b87d..d9218d5 100644
--- a/sync/azuredatabricks.py
+++ b/sync/azuredatabricks.py
@@ -1,11 +1,11 @@
+import json
 import logging
 import os
 import sys
 from time import sleep
-from typing import List, Dict, Type, TypeVar, Optional
+from typing import Dict, List, Optional, Type, TypeVar
 from urllib.parse import urlparse
 
-import orjson
 from azure.common.credentials import get_cli_profile
 from azure.core.exceptions import ClientAuthenticationError
 from azure.identity import DefaultAzureCredential
@@ -59,6 +59,7 @@
     Response,
 )
 from sync.utils.dbfs import format_dbfs_filepath, write_dbfs_file
+from sync.utils.json import DefaultDateTimeEncoder
 
 __all__ = [
     "get_access_report",
@@ -266,9 +267,7 @@ def _get_cluster_instances(cluster: dict) -> Response[dict]:
     )
 
     cluster_instances = (
-        orjson.loads(cluster_instances_file_response)
-        if cluster_instances_file_response
-        else None
+        json.loads(cluster_instances_file_response) if cluster_instances_file_response else None
     )
 
     # If this cluster does not have the "Sync agent" configured, attempt a best-effort snapshot of the instances that
@@ -371,11 +370,15 @@ def write_file(body: bytes):
         all_timelines = retired_timelines + list(active_timelines_by_id.values())
 
         write_file(
-            orjson.dumps(
-                {
-                    "instances": list(all_vms_by_id.values()),
-                    "timelines": all_timelines,
-                }
+            bytes(
+                json.dumps(
+                    {
+                        "instances": list(all_vms_by_id.values()),
+                        "timelines": all_timelines,
+                    },
+                    cls=DefaultDateTimeEncoder,
+                ),
+                "utf-8",
             )
         )
 
diff --git a/sync/cli/_databricks.py b/sync/cli/_databricks.py
index 0a9b2aa..c0857e6 100644
--- a/sync/cli/_databricks.py
+++ b/sync/cli/_databricks.py
@@ -1,7 +1,7 @@
+import json
 from typing import Tuple
 
 import click
-import orjson
 
 from sync.api.projects import (
     create_project_recommendation,
@@ -11,6 +11,7 @@
 from sync.cli.util import validate_project
 from sync.config import CONFIG
 from sync.models import DatabricksComputeType, DatabricksPlanType, Platform, Preference
+from sync.utils.json import DateTimeEncoderNaiveUTC
 
 pass_platform = click.make_pass_decorator(Platform)
 
@@ -202,9 +203,10 @@ def get_recommendation(project: dict, recommendation_id: str):
             click.echo("Recommendation generation failed.", err=True)
         else:
             click.echo(
-                orjson.dumps(
+                json.dumps(
                     recommendation,
-                    option=orjson.OPT_INDENT_2 | orjson.OPT_NAIVE_UTC | orjson.OPT_UTC_Z,
+                    indent=2,
+                    cls=DateTimeEncoderNaiveUTC,
                 )
             )
     else:
@@ -223,9 +225,10 @@ def get_submission(project: dict, submission_id: str):
             click.echo("Submission generation failed.", err=True)
         else:
             click.echo(
-                orjson.dumps(
+                json.dumps(
                     submission,
-                    option=orjson.OPT_INDENT_2 | orjson.OPT_NAIVE_UTC | orjson.OPT_UTC_Z,
+                    indent=2,
+                    cls=DateTimeEncoderNaiveUTC,
                 )
             )
     else:
@@ -277,9 +280,10 @@ def get_cluster_report(
     config = config_response.result
     if config:
         click.echo(
-            orjson.dumps(
+            json.dumps(
                 config.dict(exclude_none=True),
-                option=orjson.OPT_INDENT_2 | orjson.OPT_NAIVE_UTC | orjson.OPT_UTC_Z,
+                indent=2,
+                cls=DateTimeEncoderNaiveUTC,
             )
         )
     else:
diff --git a/sync/cli/awsemr.py b/sync/cli/awsemr.py
index d0b90c0..26cbc1a 100644
--- a/sync/cli/awsemr.py
+++ b/sync/cli/awsemr.py
@@ -1,14 +1,15 @@
+import json
 from io import TextIOWrapper
 from typing import Dict
 
 import click
-import orjson
 
 from sync import awsemr
 from sync.api.predictions import get_prediction
 from sync.cli.util import validate_project
 from sync.config import CONFIG
 from sync.models import Platform, Preference
+from sync.utils.json import DateTimeEncoderNaiveUTC
 
 
 @click.group
@@ -34,7 +35,7 @@ def run_job_flow(job_flow: TextIOWrapper, project: dict = None, region: str = No
     """Run a job flow
 
     JOB_FLOW is a file containing the RunJobFlow request object"""
-    job_flow_obj = orjson.loads(job_flow.read())
+    job_flow_obj = json.loads(job_flow.read())
 
     run_response = awsemr.run_and_record_job_flow(
         job_flow_obj, project["id"] if project else None, region
@@ -125,11 +126,7 @@ def get_cluster_report(cluster_id: str, region: str = None):
     config_response = awsemr.get_cluster_report(cluster_id, region)
     config = config_response.result
     if config:
-        click.echo(
-            orjson.dumps(
-                config, option=orjson.OPT_INDENT_2 | orjson.OPT_NAIVE_UTC | orjson.OPT_UTC_Z
-            )
-        )
+        click.echo(json.dumps(config, indent=2, cls=DateTimeEncoderNaiveUTC))
     else:
         click.echo(f"Failed to create prediction. {config_response.error}", err=True)
 
diff --git a/sync/cli/predictions.py b/sync/cli/predictions.py
index 2bcf552..8575ffb 100644
--- a/sync/cli/predictions.py
+++ b/sync/cli/predictions.py
@@ -1,10 +1,10 @@
 import io
+import json
 from pathlib import Path
 from urllib.parse import urlparse
 
 import boto3 as boto
 import click
-import orjson
 
 from sync.api.predictions import (
     create_prediction,
@@ -17,6 +17,7 @@
 from sync.cli.util import validate_project
 from sync.config import CONFIG
 from sync.models import Platform, Preference
+from sync.utils.json import DateTimeEncoderNaiveUTCDropMicroseconds
 
 
 @click.group
@@ -48,12 +49,12 @@ def generate(
     parsed_report_arg = urlparse(report)
     if parsed_report_arg.scheme == "":
         with open(report) as report_fobj:
-            report = orjson.loads(report_fobj.read())
+            report = json.loads(report_fobj.read())
     elif parsed_report_arg.scheme == "s3":
         s3 = boto.client("s3")
         report_io = io.BytesIO()
         s3.download_fileobj(parsed_report_arg.netloc, parsed_report_arg.path.lstrip("/"), report_io)
-        report = orjson.loads(report_io.getvalue())
+        report = json.loads(report_io.getvalue())
     else:
         ctx.fail("Unsupported report argument")
 
@@ -83,13 +84,7 @@ def generate(
     prediction = prediction_response.result
     if prediction:
         click.echo(
-            orjson.dumps(
-                prediction,
-                option=orjson.OPT_INDENT_2
-                | orjson.OPT_UTC_Z
-                | orjson.OPT_NAIVE_UTC
-                | orjson.OPT_OMIT_MICROSECONDS,
-            )
+            json.dumps(prediction, indent=2, cls=DateTimeEncoderNaiveUTCDropMicroseconds)
         )
     else:
         click.echo(str(response.error), err=True)
@@ -108,12 +103,12 @@ def create(ctx: click.Context, platform: Platform, event_log: str, report: str,
     parsed_report_arg = urlparse(report)
     if parsed_report_arg.scheme == "":
         with open(report) as report_fobj:
-            report = orjson.loads(report_fobj.read())
+            report = json.loads(report_fobj.read())
     elif parsed_report_arg.scheme == "s3":
         s3 = boto.client("s3")
         report_io = io.BytesIO()
         s3.download_fileobj(parsed_report_arg.netloc, parsed_report_arg.path.lstrip("/"), report_io)
-        report = orjson.loads(report_io.getvalue())
+        report = json.loads(report_io.getvalue())
     else:
         ctx.fail("Unsupported report argument")
 
@@ -161,15 +156,7 @@ def status(prediction_id: str):
 def get(prediction_id: str, preference: Preference):
     """Retrieve a prediction"""
     response = get_prediction(prediction_id, preference.value)
-    click.echo(
-        orjson.dumps(
-            response.result,
-            option=orjson.OPT_INDENT_2
-            | orjson.OPT_UTC_Z
-            | orjson.OPT_NAIVE_UTC
-            | orjson.OPT_OMIT_MICROSECONDS,
-        )
-    )
+    click.echo(json.dumps(response.result, indent=2, cls=DateTimeEncoderNaiveUTCDropMicroseconds))
 
 
 @predictions.command
diff --git a/sync/cli/projects.py b/sync/cli/projects.py
index 1e110b6..cf483d2 100644
--- a/sync/cli/projects.py
+++ b/sync/cli/projects.py
@@ -1,5 +1,6 @@
+import json
+
 import click
-import orjson
 
 from sync.api.projects import (
     create_project,
@@ -12,6 +13,7 @@
 from sync.cli.util import validate_project
 from sync.config import CONFIG
 from sync.models import Preference
+from sync.utils.json import DateTimeEncoderNaiveUTCDropMicroseconds
 
 
 @click.group
@@ -40,12 +42,7 @@ def get(project: dict):
     response = get_project(project["id"])
     project = response.result
     if project:
-        click.echo(
-            orjson.dumps(
-                project,
-                option=orjson.OPT_INDENT_2 | orjson.OPT_UTC_Z | orjson.OPT_OMIT_MICROSECONDS,
-            )
-        )
+        click.echo(json.dumps(project, indent=2, cls=DateTimeEncoderNaiveUTCDropMicroseconds))
     else:
         click.echo(str(response.error), err=True)
 
@@ -183,14 +180,6 @@ def get_latest_prediction(project: dict, preference: Preference):
     prediction_response = get_prediction(project["id"], preference)
     prediction = prediction_response.result
    if prediction:
-        click.echo(
-            orjson.dumps(
-                prediction,
-                option=orjson.OPT_INDENT_2
-                | orjson.OPT_UTC_Z
-                | orjson.OPT_NAIVE_UTC
-                | orjson.OPT_OMIT_MICROSECONDS,
-            )
-        )
+        click.echo(json.dumps(prediction, indent=2, cls=DateTimeEncoderNaiveUTCDropMicroseconds))
     else:
         click.echo(str(prediction_response.error), err=True)
diff --git a/sync/cli/workspaces.py b/sync/cli/workspaces.py
index a3367c9..4d25297 100644
--- a/sync/cli/workspaces.py
+++ b/sync/cli/workspaces.py
@@ -1,10 +1,12 @@
+import json
+
 import click
-import orjson
 
 from sync.api import workspace
 from sync.cli.util import OPTIONAL_DEFAULT, validate_project
 from sync.config import API_KEY, DB_CONFIG
 from sync.models import DatabricksPlanType
+from sync.utils.json import DateTimeEncoderNaiveUTCDropMicroseconds
 
 
 @click.group
@@ -75,12 +77,7 @@ def create_workspace_config(
     )
     config = response.result
     if config:
-        click.echo(
-            orjson.dumps(
-                config,
-                option=orjson.OPT_INDENT_2 | orjson.OPT_UTC_Z | orjson.OPT_OMIT_MICROSECONDS,
-            )
-        )
+        click.echo(json.dumps(config, indent=2, cls=DateTimeEncoderNaiveUTCDropMicroseconds))
     else:
         click.echo(str(response.error), err=True)
 
@@ -92,9 +89,10 @@ def get_workspace_config(workspace_id: str):
     config = config_response.result
     if config:
         click.echo(
-            orjson.dumps(
+            json.dumps(
                 config,
-                option=orjson.OPT_INDENT_2 | orjson.OPT_UTC_Z | orjson.OPT_OMIT_MICROSECONDS,
+                indent=2,
+                cls=DateTimeEncoderNaiveUTCDropMicroseconds,
             )
         )
     else:
@@ -173,9 +171,10 @@ def update_workspace_config(
     config = update_config_response.result
     if config:
         click.echo(
-            orjson.dumps(
+            json.dumps(
                 config,
-                option=orjson.OPT_INDENT_2 | orjson.OPT_UTC_Z | orjson.OPT_OMIT_MICROSECONDS,
+                indent=2,
+                cls=DateTimeEncoderNaiveUTCDropMicroseconds,
             )
         )
     else:
@@ -205,9 +204,10 @@ def reset_webhook_creds(workspace_id: str):
     result = response.result
     if result:
         click.echo(
-            orjson.dumps(
+            json.dumps(
                 result,
-                option=orjson.OPT_INDENT_2 | orjson.OPT_UTC_Z | orjson.OPT_OMIT_MICROSECONDS,
+                indent=2,
+                cls=DateTimeEncoderNaiveUTCDropMicroseconds,
             )
         )
     else:
diff --git a/sync/clients/__init__.py b/sync/clients/__init__.py
index d8ecf62..ff1a5d8 100644
--- a/sync/clients/__init__.py
+++ b/sync/clients/__init__.py
@@ -1,23 +1,24 @@
+import json
+from typing import Set, Tuple, Union
+
 import httpx
-import orjson
 from tenacity import Retrying, TryAgain, stop_after_attempt, wait_exponential_jitter
-from typing import Tuple, Union, Set
 
 from sync import __version__
+from sync.utils.json import DateTimeEncoderNaiveUTCDropMicroseconds
 
 USER_AGENT = f"Sync Library/{__version__} (syncsparkpy)"
 
 
 def encode_json(obj: dict) -> Tuple[dict, str]:
     # "%Y-%m-%dT%H:%M:%SZ"
-    options = orjson.OPT_UTC_Z | orjson.OPT_OMIT_MICROSECONDS | orjson.OPT_NAIVE_UTC
-    json = orjson.dumps(obj, option=options).decode()
+    json_obj = json.dumps(obj, cls=DateTimeEncoderNaiveUTCDropMicroseconds)
 
     return {
-        "Content-Length": str(len(json)),
+        "Content-Length": str(len(json_obj)),
         "Content-Type": "application/json",
-    }, json
+    }, json_obj
 
 
 class RetryableHTTPClient:
diff --git a/sync/utils/json.py b/sync/utils/json.py
new file mode 100644
index 0000000..fc783ca
--- /dev/null
+++ b/sync/utils/json.py
@@ -0,0 +1,34 @@
+import datetime
+from json import JSONEncoder
+
+
+class DefaultDateTimeEncoder(JSONEncoder):
+    # this copies orjson's default behavior when serializing datetimes
+    def default(self, obj):
+        if isinstance(obj, datetime.datetime):
+            return obj.isoformat()
+        # defer to the base class so unsupported types still raise TypeError
+        return super().default(obj)
+
+
+class DateTimeEncoderNaiveUTC(JSONEncoder):
+    # this copies orjson's behavior when used with the options OPT_UTC_Z and OPT_NAIVE_UTC
+    def default(self, obj):
+        if isinstance(obj, datetime.datetime):
+            date = obj
+            if date.tzinfo is None:
+                date = date.replace(tzinfo=datetime.timezone.utc)
+            return date.isoformat().replace("+00:00", "Z")
+        return super().default(obj)
+
+
+class DateTimeEncoderNaiveUTCDropMicroseconds(JSONEncoder):
+    # this copies orjson's behavior when used with the options OPT_OMIT_MICROSECONDS,
+    # OPT_UTC_Z, and OPT_NAIVE_UTC
+    def default(self, obj):
+        if isinstance(obj, datetime.datetime):
+            date = obj.replace(microsecond=0)
+            if date.tzinfo is None:
+                date = date.replace(tzinfo=datetime.timezone.utc)
+            return date.isoformat().replace("+00:00", "Z")
+        return super().default(obj)
diff --git a/tests/api/test_predictions.py b/tests/api/test_predictions.py
index bfa96ce..15bb900 100644
--- a/tests/api/test_predictions.py
+++ b/tests/api/test_predictions.py
@@ -1,4 +1,5 @@
-import orjson
+import json
+
 import respx
 from httpx import Response
 
@@ -52,7 +53,7 @@ def test_get_prediction():
     with open("tests/data/predictions_response.json") as predictions_fobj:
         prediction = [
             p
-            for p in orjson.loads(predictions_fobj.read())["result"]
+            for p in json.loads(predictions_fobj.read())["result"]
             if p["prediction_id"] == prediction_id
         ][0]
 
diff --git a/tests/asyncapi/test_asyncpredictions.py b/tests/asyncapi/test_asyncpredictions.py
index 0ed2922..2625583 100644
--- a/tests/asyncapi/test_asyncpredictions.py
+++ b/tests/asyncapi/test_asyncpredictions.py
@@ -1,4 +1,5 @@
-import orjson
+import json
+
 import pytest
 import respx
 from httpx import Response
@@ -33,7 +34,7 @@ async def test_generate_prediction():
     with open("tests/data/predictions_response.json") as predictions_fobj:
         prediction = [
             p
-            for p in orjson.loads(predictions_fobj.read())["result"]
+            for p in json.loads(predictions_fobj.read())["result"]
             if p["prediction_id"] == prediction_id
         ][0]
     mock_router.get(f"/v1/autotuner/predictions/{prediction_id}").mock(
diff --git a/tests/test_awsdatabricks.py b/tests/test_awsdatabricks.py
index ebceae7..c586d2b 100644
--- a/tests/test_awsdatabricks.py
+++ b/tests/test_awsdatabricks.py
@@ -1,11 +1,11 @@
 import copy
 import io
+import json
 from datetime import datetime
 from unittest.mock import Mock, patch
 from uuid import uuid4
 
 import boto3 as boto
-import orjson
 from botocore.response import StreamingBody
 from botocore.stub import Stubber
 from httpx import Response
@@ -14,6 +14,7 @@
 from sync.config import DatabricksConf
 from sync.models import DatabricksAPIError, DatabricksError
 from sync.models import Response as SyncResponse
+from sync.utils.json import DateTimeEncoderNaiveUTCDropMicroseconds
 
 MOCK_RUN = {
     "job_id": 12345678910,
@@ -812,14 +813,17 @@ def test_create_prediction_for_run_success_with_cluster_instance_file(respx_mock
     s3 = boto.client("s3")
     s3_stubber = Stubber(s3)
 
-    mock_cluster_info_bytes = orjson.dumps(
-        {
-            "volumes": MOCK_VOLUMES["Volumes"],
-            "instances": [
-                inst for res in MOCK_INSTANCES["Reservations"] for inst in res["Instances"]
-            ],
-        },
-        option=orjson.OPT_UTC_Z | orjson.OPT_OMIT_MICROSECONDS | orjson.OPT_NAIVE_UTC,
+    mock_cluster_info_bytes = bytes(
+        json.dumps(
+            {
+                "volumes": MOCK_VOLUMES["Volumes"],
+                "instances": [
+                    inst for res in MOCK_INSTANCES["Reservations"] for inst in res["Instances"]
+                ],
+            },
+            cls=DateTimeEncoderNaiveUTCDropMicroseconds,
+        ),
+        "utf-8",
     )
     s3_stubber.add_response(
         "get_object",
diff --git a/tests/test_awsemr.py b/tests/test_awsemr.py
index 3a98e07..862a341 100644
--- a/tests/test_awsemr.py
+++ b/tests/test_awsemr.py
@@ -1,7 +1,7 @@
+import json
 from unittest.mock import Mock, patch
 
 import boto3 as boto
-import orjson
 from botocore.stub import ANY, Stubber
 from dateutil.parser import parse
 from deepdiff import DeepDiff
@@ -20,7 +20,7 @@ def test_create_prediction(create_prediction, get_cluster_report):
     with open("tests/data/emr-cluster-report.json") as emr_cluster_report_fobj:
         get_cluster_report.return_value = Response(
-            result=orjson.loads(emr_cluster_report_fobj.read())
+            result=json.loads(emr_cluster_report_fobj.read())
         )
 
     prediction_id = "320554b0-3972-4b7c-9e41-c8efdbdc042c"
@@ -62,7 +62,7 @@
 
 def test_get_cluster_report():
     with open("tests/data/emr-cluster-report.json") as emr_cluster_report_fobj:
-        emr_cluster_report = orjson.loads(emr_cluster_report_fobj.read())
+        emr_cluster_report = json.loads(emr_cluster_report_fobj.read())
 
     cluster_id = emr_cluster_report["Cluster"]["Id"]
     region = emr_cluster_report["Region"]
@@ -103,7 +103,7 @@
 @patch("sync.awsemr.get_project")
 def test_get_project_report(get_project, get_cluster_report):
     with open("tests/data/emr-cluster-report.json") as emr_cluster_report_fobj:
-        cluster_report = orjson.loads(emr_cluster_report_fobj.read())
+        cluster_report = json.loads(emr_cluster_report_fobj.read())
 
     get_cluster_report.return_value = Response(result=cluster_report)
     get_project.return_value = Response(