1 change: 0 additions & 1 deletion pyproject.toml
@@ -29,7 +29,6 @@ dependencies = [
     "boto3~=1.26.0",
     "pydantic~=1.10.0",
     "httpx~=0.23.0",
-    "orjson~=3.8.0",
     "click~=8.1.0",
     "tenacity==8.2.2",
     "azure-identity==1.13.0",
2 changes: 1 addition & 1 deletion sync/__init__.py
@@ -1,4 +1,4 @@
 """Library for leveraging the power of Sync"""
-__version__ = "0.5.1"
+__version__ = "0.5.2"

 TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"
21 changes: 13 additions & 8 deletions sync/awsdatabricks.py
@@ -1,11 +1,11 @@
+import json
 import logging
 from time import sleep
 from typing import List, Tuple
 from urllib.parse import urlparse

 import boto3 as boto
 import botocore
-import orjson
 from botocore.exceptions import ClientError

 import sync._databricks
@@ -56,6 +56,7 @@
     Response,
 )
 from sync.utils.dbfs import format_dbfs_filepath, write_dbfs_file
+from sync.utils.json import DefaultDateTimeEncoder

 __all__ = [
     "get_access_report",
@@ -273,7 +274,7 @@ def _load_aws_cluster_info(cluster: dict) -> Tuple[Response[dict], Response[dict
     cluster_info_file_response = _get_cluster_instances_from_dbfs(cluster_info_file_key)

     cluster_info = (
-        orjson.loads(cluster_info_file_response) if cluster_info_file_response else None
+        json.loads(cluster_info_file_response) if cluster_info_file_response else None
     )

     # If this cluster does not have the "Sync agent" configured, attempt a best-effort snapshot of the instances that
@@ -409,12 +410,16 @@ def write_file(body: bytes):
         all_timelines = retired_timelines + list(active_timelines_by_id.values())

         write_file(
-            orjson.dumps(
-                {
-                    "instances": list(all_inst_by_id.values()),
-                    "instance_timelines": all_timelines,
-                    "volumes": list(recorded_volumes_by_id.values()),
-                }
+            bytes(
+                json.dumps(
+                    {
+                        "instances": list(all_inst_by_id.values()),
+                        "instance_timelines": all_timelines,
+                        "volumes": list(recorded_volumes_by_id.values()),
+                    },
+                    cls=DefaultDateTimeEncoder,
+                ),
+                "utf-8",
             )
         )
     except Exception as e:
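
Reviewer note: DefaultDateTimeEncoder is imported from sync/utils/json, a module this diff does not touch, so its definition is not shown here. Stdlib json, unlike orjson, has no native datetime support, so here is a minimal sketch of what such an encoder plausibly looks like (a hypothetical reconstruction, not the module's actual code):

# Hypothetical sketch of sync/utils/json.DefaultDateTimeEncoder.
# json.JSONEncoder calls default() only for objects it cannot serialize
# natively, so returning ISO 8601 strings here restores the datetime
# handling that orjson provided out of the box.
import datetime
import json

class DefaultDateTimeEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, (datetime.datetime, datetime.date)):
            return obj.isoformat()
        return super().default(obj)

# orjson.dumps returned bytes while json.dumps returns str; the
# bytes(..., "utf-8") wrapper at the call site keeps write_file(body: bytes)
# receiving bytes.
payload = bytes(
    json.dumps({"at": datetime.datetime(2023, 6, 1, 12, 30)}, cls=DefaultDateTimeEncoder),
    "utf-8",
)
assert payload == b'{"at": "2023-06-01T12:30:00"}'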
10 changes: 4 additions & 6 deletions sync/awsemr.py
@@ -4,6 +4,7 @@

 import datetime
 import io
+import json
 import logging
 import re
 from copy import deepcopy
@@ -12,7 +13,6 @@
 from uuid import uuid4

 import boto3 as boto
-import orjson
 from dateutil.parser import parse as dateparse

 from sync import TIME_FORMAT
@@ -28,6 +28,7 @@
     ProjectError,
     Response,
 )
+from sync.utils.json import DateTimeEncoderNaiveUTCDropMicroseconds

 logger = logging.getLogger(__name__)

@@ -364,7 +365,7 @@ def get_project_cluster_report(  # noqa: C901
     s3.download_fileobj(parsed_project_url.netloc, config_key, config)
     return Response(
         result=(
-            orjson.loads(config.getvalue().decode()),
+            json.loads(config.getvalue().decode()),
             f"s3://{parsed_project_url.netloc}/{log_key}",
         )
     )
@@ -753,10 +754,7 @@ def _upload_object(obj: dict, s3_url: str) -> Response[str]:
     s3 = boto.client("s3")
     s3.upload_fileobj(
         io.BytesIO(
-            orjson.dumps(
-                obj,
-                option=orjson.OPT_UTC_Z | orjson.OPT_OMIT_MICROSECONDS | orjson.OPT_NAIVE_UTC,
-            )
+            bytes(json.dumps(obj, cls=DateTimeEncoderNaiveUTCDropMicroseconds), "utf-8")
         ),
         parsed_url.netloc,
         obj_key,
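
The three orjson options removed from _upload_object each map to one line of encoder logic. A sketch of what DateTimeEncoderNaiveUTCDropMicroseconds would need to do to preserve the old output, assuming it mirrors those options (the real class lives in sync/utils/json, outside this diff):

import datetime
import json

# Hypothetical sketch of sync/utils/json.DateTimeEncoderNaiveUTCDropMicroseconds.
class DateTimeEncoderNaiveUTCDropMicroseconds(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, datetime.datetime):
            if obj.tzinfo is None:
                # OPT_NAIVE_UTC: treat naive datetimes as UTC
                obj = obj.replace(tzinfo=datetime.timezone.utc)
            # OPT_OMIT_MICROSECONDS: truncate to whole seconds
            text = obj.replace(microsecond=0).isoformat()
            # OPT_UTC_Z: render the UTC offset as "Z" rather than "+00:00"
            return text.replace("+00:00", "Z")
        return super().default(obj)

naive = datetime.datetime(2023, 6, 1, 12, 30, 15, 123456)
assert json.dumps(naive, cls=DateTimeEncoderNaiveUTCDropMicroseconds) == '"2023-06-01T12:30:15Z"'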
23 changes: 13 additions & 10 deletions sync/azuredatabricks.py
@@ -1,11 +1,11 @@
+import json
 import logging
 import os
 import sys
 from time import sleep
-from typing import List, Dict, Type, TypeVar, Optional
+from typing import Dict, List, Optional, Type, TypeVar
 from urllib.parse import urlparse

-import orjson
 from azure.common.credentials import get_cli_profile
 from azure.core.exceptions import ClientAuthenticationError
 from azure.identity import DefaultAzureCredential
@@ -59,6 +59,7 @@
     Response,
 )
 from sync.utils.dbfs import format_dbfs_filepath, write_dbfs_file
+from sync.utils.json import DefaultDateTimeEncoder

 __all__ = [
     "get_access_report",
@@ -266,9 +267,7 @@ def _get_cluster_instances(cluster: dict) -> Response[dict]:
     )

     cluster_instances = (
-        orjson.loads(cluster_instances_file_response)
-        if cluster_instances_file_response
-        else None
+        json.loads(cluster_instances_file_response) if cluster_instances_file_response else None
     )

     # If this cluster does not have the "Sync agent" configured, attempt a best-effort snapshot of the instances that
@@ -371,11 +370,15 @@ def write_file(body: bytes):
         all_timelines = retired_timelines + list(active_timelines_by_id.values())

         write_file(
-            orjson.dumps(
-                {
-                    "instances": list(all_vms_by_id.values()),
-                    "timelines": all_timelines,
-                }
+            bytes(
+                json.dumps(
+                    {
+                        "instances": list(all_vms_by_id.values()),
+                        "timelines": all_timelines,
+                    },
+                    cls=DefaultDateTimeEncoder,
+                ),
+                "utf-8",
             )
         )

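
One behavioral difference of the swap, here and in awsdatabricks.py: json.dumps puts a space after ":" and "," by default, while orjson emitted compact JSON. The bytes written to DBFS therefore change, but anything that parses them sees identical content:

import json

# What orjson.dumps produced (compact separators) vs. the stdlib default.
orjson_style = b'{"instances":[],"timelines":[]}'
stdlib_style = bytes(json.dumps({"instances": [], "timelines": []}), "utf-8")
assert stdlib_style == b'{"instances": [], "timelines": []}'
# Parsed content is unchanged, so downstream readers are unaffected.
assert json.loads(orjson_style) == json.loads(stdlib_style)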
18 changes: 11 additions & 7 deletions sync/cli/_databricks.py
@@ -1,7 +1,7 @@
+import json
 from typing import Tuple

 import click
-import orjson

 from sync.api.projects import (
     create_project_recommendation,
@@ -11,6 +11,7 @@
 from sync.cli.util import validate_project
 from sync.config import CONFIG
 from sync.models import DatabricksComputeType, DatabricksPlanType, Platform, Preference
+from sync.utils.json import DateTimeEncoderNaiveUTC

 pass_platform = click.make_pass_decorator(Platform)

@@ -202,9 +203,10 @@ def get_recommendation(project: dict, recommendation_id: str):
             click.echo("Recommendation generation failed.", err=True)
         else:
             click.echo(
-                orjson.dumps(
+                json.dumps(
                     recommendation,
-                    option=orjson.OPT_INDENT_2 | orjson.OPT_NAIVE_UTC | orjson.OPT_UTC_Z,
+                    indent=2,
+                    cls=DateTimeEncoderNaiveUTC,
                 )
             )
     else:
@@ -223,9 +225,10 @@ def get_submission(project: dict, submission_id: str):
             click.echo("Submission generation failed.", err=True)
         else:
             click.echo(
-                orjson.dumps(
+                json.dumps(
                     submission,
-                    option=orjson.OPT_INDENT_2 | orjson.OPT_NAIVE_UTC | orjson.OPT_UTC_Z,
+                    indent=2,
+                    cls=DateTimeEncoderNaiveUTC,
                 )
             )
     else:
@@ -277,9 +280,10 @@ def get_cluster_report(
     config = config_response.result
     if config:
         click.echo(
-            orjson.dumps(
+            json.dumps(
                 config.dict(exclude_none=True),
-                option=orjson.OPT_INDENT_2 | orjson.OPT_NAIVE_UTC | orjson.OPT_UTC_Z,
+                indent=2,
+                cls=DateTimeEncoderNaiveUTC,
             )
         )
     else:
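
In the CLI commands above, indent=2 replaces orjson.OPT_INDENT_2, and DateTimeEncoderNaiveUTC presumably covers OPT_NAIVE_UTC | OPT_UTC_Z without dropping microseconds. A sketch under that assumption (the class is defined outside this diff):

import datetime
import json

# Hypothetical sketch of sync/utils/json.DateTimeEncoderNaiveUTC.
class DateTimeEncoderNaiveUTC(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, datetime.datetime):
            if obj.tzinfo is None:
                obj = obj.replace(tzinfo=datetime.timezone.utc)  # OPT_NAIVE_UTC
            return obj.isoformat().replace("+00:00", "Z")  # OPT_UTC_Z
        return super().default(obj)

No further change is needed at these call sites: click.echo accepts the str returned by json.dumps just as it accepted the bytes returned by orjson.dumps.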
11 changes: 4 additions & 7 deletions sync/cli/awsemr.py
@@ -1,14 +1,15 @@
+import json
 from io import TextIOWrapper
 from typing import Dict

 import click
-import orjson

 from sync import awsemr
 from sync.api.predictions import get_prediction
 from sync.cli.util import validate_project
 from sync.config import CONFIG
 from sync.models import Platform, Preference
+from sync.utils.json import DateTimeEncoderNaiveUTC


 @click.group
@@ -34,7 +35,7 @@ def run_job_flow(job_flow: TextIOWrapper, project: dict = None, region: str = No
     """Run a job flow

     JOB_FLOW is a file containing the RunJobFlow request object"""
-    job_flow_obj = orjson.loads(job_flow.read())
+    job_flow_obj = json.loads(job_flow.read())

     run_response = awsemr.run_and_record_job_flow(
         job_flow_obj, project["id"] if project else None, region
@@ -125,11 +126,7 @@ def get_cluster_report(cluster_id: str, region: str = None):
     config_response = awsemr.get_cluster_report(cluster_id, region)
     config = config_response.result
     if config:
-        click.echo(
-            orjson.dumps(
-                config, option=orjson.OPT_INDENT_2 | orjson.OPT_NAIVE_UTC | orjson.OPT_UTC_Z
-            )
-        )
+        click.echo(json.dumps(config, indent=2, cls=DateTimeEncoderNaiveUTC))
    else:
        click.echo(f"Failed to create prediction. {config_response.error}", err=True)

29 changes: 8 additions & 21 deletions sync/cli/predictions.py
@@ -1,10 +1,10 @@
 import io
+import json
 from pathlib import Path
 from urllib.parse import urlparse

 import boto3 as boto
 import click
-import orjson

 from sync.api.predictions import (
     create_prediction,
@@ -17,6 +17,7 @@
 from sync.cli.util import validate_project
 from sync.config import CONFIG
 from sync.models import Platform, Preference
+from sync.utils.json import DateTimeEncoderNaiveUTCDropMicroseconds


 @click.group
@@ -48,12 +49,12 @@ def generate(
     parsed_report_arg = urlparse(report)
     if parsed_report_arg.scheme == "":
         with open(report) as report_fobj:
-            report = orjson.loads(report_fobj.read())
+            report = json.loads(report_fobj.read())
     elif parsed_report_arg.scheme == "s3":
         s3 = boto.client("s3")
         report_io = io.BytesIO()
         s3.download_fileobj(parsed_report_arg.netloc, parsed_report_arg.path.lstrip("/"), report_io)
-        report = orjson.loads(report_io.getvalue())
+        report = json.loads(report_io.getvalue())
     else:
         ctx.fail("Unsupported report argument")

@@ -83,13 +84,7 @@ def generate(
     prediction = prediction_response.result
     if prediction:
         click.echo(
-            orjson.dumps(
-                prediction,
-                option=orjson.OPT_INDENT_2
-                | orjson.OPT_UTC_Z
-                | orjson.OPT_NAIVE_UTC
-                | orjson.OPT_OMIT_MICROSECONDS,
-            )
+            json.dumps(prediction, indent=2, cls=DateTimeEncoderNaiveUTCDropMicroseconds)
         )
     else:
         click.echo(str(response.error), err=True)
@@ -108,12 +103,12 @@ def create(ctx: click.Context, platform: Platform, event_log: str, report: str,
     parsed_report_arg = urlparse(report)
     if parsed_report_arg.scheme == "":
         with open(report) as report_fobj:
-            report = orjson.loads(report_fobj.read())
+            report = json.loads(report_fobj.read())
     elif parsed_report_arg.scheme == "s3":
         s3 = boto.client("s3")
         report_io = io.BytesIO()
         s3.download_fileobj(parsed_report_arg.netloc, parsed_report_arg.path.lstrip("/"), report_io)
-        report = orjson.loads(report_io.getvalue())
+        report = json.loads(report_io.getvalue())
     else:
         ctx.fail("Unsupported report argument")

@@ -161,15 +156,7 @@ def status(prediction_id: str):
 def get(prediction_id: str, preference: Preference):
     """Retrieve a prediction"""
     response = get_prediction(prediction_id, preference.value)
-    click.echo(
-        orjson.dumps(
-            response.result,
-            option=orjson.OPT_INDENT_2
-            | orjson.OPT_UTC_Z
-            | orjson.OPT_NAIVE_UTC
-            | orjson.OPT_OMIT_MICROSECONDS,
-        )
-    )
+    click.echo(json.dumps(response.result, indent=2, cls=DateTimeEncoderNaiveUTCDropMicroseconds))


 @predictions.command
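
The S3 branches above pass report_io.getvalue(), which is raw bytes, straight to json.loads with no .decode(). That works because json.loads has accepted bytes and bytearray input since Python 3.6:

import json

# bytes and str inputs parse identically (Python >= 3.6).
assert json.loads(b'{"id": 1}') == json.loads('{"id": 1}')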
21 changes: 5 additions & 16 deletions sync/cli/projects.py
@@ -1,5 +1,6 @@
+import json
+
 import click
-import orjson

 from sync.api.projects import (
     create_project,
@@ -12,6 +13,7 @@
 from sync.cli.util import validate_project
 from sync.config import CONFIG
 from sync.models import Preference
+from sync.utils.json import DateTimeEncoderNaiveUTCDropMicroseconds


 @click.group
@@ -40,12 +42,7 @@ def get(project: dict):
     response = get_project(project["id"])
     project = response.result
     if project:
-        click.echo(
-            orjson.dumps(
-                project,
-                option=orjson.OPT_INDENT_2 | orjson.OPT_UTC_Z | orjson.OPT_OMIT_MICROSECONDS,
-            )
-        )
+        click.echo(json.dumps(project, indent=2, cls=DateTimeEncoderNaiveUTCDropMicroseconds))
     else:
         click.echo(str(response.error), err=True)

@@ -183,14 +180,6 @@ def get_latest_prediction(project: dict, preference: Preference):
     prediction_response = get_prediction(project["id"], preference)
     prediction = prediction_response.result
     if prediction:
-        click.echo(
-            orjson.dumps(
-                prediction,
-                option=orjson.OPT_INDENT_2
-                | orjson.OPT_UTC_Z
-                | orjson.OPT_NAIVE_UTC
-                | orjson.OPT_OMIT_MICROSECONDS,
-            )
-        )
+        click.echo(json.dumps(prediction, indent=2, cls=DateTimeEncoderNaiveUTCDropMicroseconds))
     else:
         click.echo(str(prediction_response.error), err=True)