Skip to content

Commit

Permalink
Improve serialization logic for charts, text, and htmls using jsonl
Browse files Browse the repository at this point in the history
  • Loading branch information
polyaxon-ci committed Sep 28, 2023
1 parent ea6cc04 commit 5b1c74f
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 33 deletions.
8 changes: 6 additions & 2 deletions traceml/tests/test_tracking/test_run_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,9 @@ def test_log_mpl_plotly(self):
self.run_path, kind=V1ArtifactKind.CHART, name="figure"
)
assert os.path.exists(events_file) is True
results = V1Events.read(kind="image", name="figure", data=events_file)
results = V1Events.read(
kind=V1ArtifactKind.CHART, name="figure", data=events_file
)
assert len(results.df.values) == 1

with patch("traceml.tracking.run.Run._log_has_events") as log_mpl_plotly_chart:
Expand All @@ -842,7 +844,9 @@ def test_log_mpl_plotly(self):
self.run_path, kind=V1ArtifactKind.CHART, name="figure"
)
assert os.path.exists(events_file) is True
results = V1Events.read(kind="image", name="figure", data=events_file)
results = V1Events.read(
kind=V1ArtifactKind.CHART, name="figure", data=events_file
)
assert len(results.df.values) == 2

def test_log_video_from_path(self):
Expand Down
8 changes: 8 additions & 0 deletions traceml/traceml/artifacts/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ class V1ArtifactKind(str, PEnum):
SYSTEM = "system"
ARTIFACT = "artifact"

@classmethod
def is_jsonl_file_event(cls, kind: Optional[Union["V1ArtifactKind", str]]) -> bool:
    """Return True when events of `kind` are persisted as jsonl files.

    Html, text, and chart events carry payloads that do not fit a csv
    row cleanly, so they are stored one json document per line.
    """
    jsonl_kinds = (
        V1ArtifactKind.HTML,
        V1ArtifactKind.TEXT,
        V1ArtifactKind.CHART,
    )
    return kind in jsonl_kinds

@classmethod
def is_single_file_event(cls, kind: Optional[Union["V1ArtifactKind", str]]) -> bool:
return kind in {
Expand Down
84 changes: 67 additions & 17 deletions traceml/traceml/events/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,58 @@ def __init__(self, kind, name, df):
self.name = name
self.df = df

@staticmethod
def _read_csv(
    data: str,
    parse_dates: Optional[bool] = True,
    engine: Optional[str] = None,
) -> "pandas.DataFrame":
    """Load csv event data into a dataframe.

    When `parse_dates` is truthy, the `timestamp` column is parsed into
    datetimes; otherwise it is coerced back to strings, because the
    pyarrow engine automatically converts timestamp fields.
    """
    import pandas as pd

    read_kwargs = {"sep": V1Event._SEPARATOR, "engine": engine}
    if parse_dates:
        read_kwargs["parse_dates"] = ["timestamp"]
    frame = pd.read_csv(data, **read_kwargs)
    if not parse_dates and "timestamp" in frame.columns:
        # Undo pyarrow's automatic timestamp conversion.
        frame["timestamp"] = frame["timestamp"].astype(str)
    return frame

@staticmethod
def _read_jsonl(
data: str,
parse_dates: Optional[bool] = True,
engine: Optional[str] = None,
) -> "pandas.DataFrame":
import pandas as pd

engine = engine or "ujson"
if parse_dates:
return pd.read_json(
data,
lines=True,
engine=engine,
)
else:
df = pd.read_json(
data,
lines=True,
engine=engine,
)
# Pyarrow automatically converts timestamp fields
if "timestamp" in df.columns:
df["timestamp"] = df["timestamp"].astype(str)
return df

@classmethod
def read(
cls,
Expand All @@ -379,23 +431,17 @@ def read(
import pandas as pd

if isinstance(data, str):
csv = validate_csv(data)
if parse_dates:
df = pd.read_csv(
csv,
sep=V1Event._SEPARATOR,
parse_dates=["timestamp"],
engine=engine,
)
else:
df = pd.read_csv(
csv,
sep=V1Event._SEPARATOR,
engine=engine,
)
# Pyarrow automatically converts timestamp fields
if "timestamp" in df.columns:
df["timestamp"] = df["timestamp"].astype(str)
data = validate_csv(data)
error = None
if V1ArtifactKind.is_jsonl_file_event(kind):
try:
df = cls._read_jsonl(
data=data, parse_dates=parse_dates, engine=engine
)
except Exception as e:
error = e
if not V1ArtifactKind.is_jsonl_file_event(kind) or error:
df = cls._read_csv(data=data, parse_dates=parse_dates, engine=engine)
elif isinstance(data, dict):
df = pd.DataFrame.from_dict(data)
else:
Expand Down Expand Up @@ -469,6 +515,10 @@ def get_csv_events(self) -> str:
events = ["\n{}".format(e.to_csv()) for e in self.events]
return "".join(events)

def get_jsonl_events(self) -> str:
    """Serialize the buffered events as newline-prefixed json lines."""
    return "".join("\n{}".format(event.to_json()) for event in self.events)

def empty_events(self):
    """Clear the event buffer in place, keeping the same list object."""
    del self.events[:]

Expand Down
37 changes: 23 additions & 14 deletions traceml/traceml/serialization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@
from clipped.utils.enums import get_enum_value
from clipped.utils.paths import check_or_create_path

from traceml.events import LoggedEventSpec
from traceml.events.schemas import LoggedEventListSpec
from traceml.artifacts import V1ArtifactKind
from traceml.events import (
LoggedEventListSpec,
LoggedEventSpec,
get_event_path,
get_resource_path,
)


class EventWriter:
Expand All @@ -21,18 +26,16 @@ def __init__(self, run_path: str, backend: str):

def _get_event_path(self, kind: str, name: str) -> str:
if self._events_backend == self.EVENTS_BACKEND:
return os.path.join(
self._run_path,
get_enum_value(self._events_backend),
kind,
"{}.plx".format(name),
return get_event_path(
run_path=self._run_path,
kind=kind,
name=name,
)
if self._events_backend == self.RESOURCES_BACKEND:
return os.path.join(
self._run_path,
get_enum_value(self._events_backend),
kind,
"{}.plx".format(name),
return get_resource_path(
run_path=self._run_path,
kind=kind,
name=name,
)
raise ValueError(
"Unrecognized backend {}".format(get_enum_value(self._events_backend))
Expand All @@ -44,12 +47,18 @@ def _init_events(self, events_spec: LoggedEventListSpec):
if not os.path.exists(event_path):
check_or_create_path(event_path, is_dir=False)
with open(event_path, "w") as event_file:
event_file.write(events_spec.get_csv_header())
if V1ArtifactKind.is_jsonl_file_event(events_spec.kind):
event_file.write("")
else:
event_file.write(events_spec.get_csv_header())

def _append_events(self, events_spec: LoggedEventListSpec):
    """Append the spec's buffered events to its on-disk event file.

    Jsonl-backed kinds (html/text/chart) are written as json lines,
    every other kind as csv rows. The kind check stays inside the
    `with` so the file is created before serialization, as before.
    """
    target = self._get_event_path(kind=events_spec.kind, name=events_spec.name)
    with open(target, "a") as stream:
        if V1ArtifactKind.is_jsonl_file_event(events_spec.kind):
            stream.write(events_spec.get_jsonl_events())
        else:
            stream.write(events_spec.get_csv_events())

def _events_to_files(self, events: List[LoggedEventSpec]):
for event in events:
Expand Down

0 comments on commit 5b1c74f

Please sign in to comment.