Transaction updates (#1464)
* More transaction integration tests
* support for distributed transactions
* Transaction context manager

- [ ] I have reviewed the [Guidelines for Contributing](CONTRIBUTING.md) and the [Code of Conduct](CODE_OF_CONDUCT.md).
richard-rogers committed Feb 20, 2024
1 parent 926af60 commit 06b15fd
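
In brief: profiles written between start_transaction() and commit_transaction() are uploaded right away but not ingested by WhyLabs until the commit, so a batch of profiles becomes visible together. A minimal sketch of the new context-manager flow (the dataset id "model-0" and the DataFrame df are placeholders):

import whylogs as why
from whylogs.api.writer.whylabs import WhyLabsTransaction, WhyLabsWriter

writer = WhyLabsWriter(dataset_id="model-0")  # placeholder dataset id
with WhyLabsTransaction(writer):  # calls start_transaction() on enter
    status, _ = writer.write(why.log(df).profile())  # df: a pandas DataFrame
# commit_transaction() runs on exit and throws on failure
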
Showing 2 changed files with 225 additions and 9 deletions.
137 changes: 136 additions & 1 deletion python/tests/api/writer/test_whylabs_integration.py
@@ -1,7 +1,9 @@
import logging
import os
import time
from uuid import uuid4

import numpy as np
import pandas as pd
import pytest
from whylabs_client.api.dataset_profile_api import DatasetProfileApi
@@ -13,7 +15,7 @@
)

import whylogs as why
-from whylogs.api.writer.whylabs import WhyLabsWriter
+from whylogs.api.writer.whylabs import WhyLabsTransaction, WhyLabsWriter
from whylogs.core import DatasetProfileView
from whylogs.core.feature_weights import FeatureWeights
from whylogs.core.schema import DatasetSchema
@@ -28,6 +30,8 @@

SLEEP_TIME = 30

logger = logging.getLogger(__name__)


@pytest.mark.load
def test_whylabs_writer():
@@ -212,3 +216,134 @@ def test_transactions():
    downloaded_profile = writer._s3_pool.request("GET", download_url, headers=headers, timeout=writer._timeout_seconds)
    deserialized_view = DatasetProfileView.deserialize(downloaded_profile.data)
    assert deserialized_view.get_columns().keys() == data.keys()


@pytest.mark.load
def test_transaction_context():
    ORG_ID = os.environ.get("WHYLABS_DEFAULT_ORG_ID")
    MODEL_ID = os.environ.get("WHYLABS_DEFAULT_DATASET_ID")
    why.init(force_local=True)
    schema = DatasetSchema()
    csv_url = "https://whylabs-public.s3.us-west-2.amazonaws.com/datasets/tour/current.csv"
    df = pd.read_csv(csv_url)
    pdfs = np.array_split(df, 7)
    writer = WhyLabsWriter(dataset_id=MODEL_ID)
    tids = list()
    try:
        with WhyLabsTransaction(writer):
            for data in pdfs:
                trace_id = str(uuid4())
                tids.append(trace_id)
                result = why.log(data, schema=schema, trace_id=trace_id)
                status, id = writer.write(result.profile())
                if not status:
                    raise Exception()  # or retry the profile...
    except Exception:
        # start_transaction() and commit_transaction() in the WhyLabsTransaction
        # context manager throw on failure; alternatively, retry the commit.
        logger.exception("Logging transaction failed")

    time.sleep(SLEEP_TIME)  # the platform needs time to become aware of the profiles
    dataset_api = DatasetProfileApi(writer._api_client)
    for trace_id in tids:
        response: ProfileTracesResponse = dataset_api.get_profile_traces(
            org_id=ORG_ID,
            dataset_id=MODEL_ID,
            trace_id=trace_id,
        )
        download_url = response.get("traces")[0]["download_url"]
        headers = {"Content-Type": "application/octet-stream"}
        downloaded_profile = writer._s3_pool.request(
            "GET", download_url, headers=headers, timeout=writer._timeout_seconds
        )
        deserialized_view = DatasetProfileView.deserialize(downloaded_profile.data)
        assert deserialized_view is not None


@pytest.mark.load
def test_transaction_segmented():
    ORG_ID = os.environ.get("WHYLABS_DEFAULT_ORG_ID")
    MODEL_ID = os.environ.get("WHYLABS_DEFAULT_DATASET_ID")
    why.init(force_local=True)
    schema = DatasetSchema(segments=segment_on_column("Gender"))
    csv_url = "https://whylabs-public.s3.us-west-2.amazonaws.com/datasets/tour/current.csv"
    data = pd.read_csv(csv_url)
    writer = WhyLabsWriter(dataset_id=MODEL_ID)
    trace_id = str(uuid4())
    try:
        writer.start_transaction()
        result = why.log(data, schema=schema, trace_id=trace_id)
        status, id = writer.write(result)
        if not status:
            raise Exception()  # or retry the profile...
    except Exception:
        # start_transaction() throws on failure; alternatively, retry the write.
        logger.exception("Logging transaction failed")

    writer.commit_transaction()
    time.sleep(SLEEP_TIME)  # the platform needs time to become aware of the profiles
    dataset_api = DatasetProfileApi(writer._api_client)
    response: ProfileTracesResponse = dataset_api.get_profile_traces(
        org_id=ORG_ID,
        dataset_id=MODEL_ID,
        trace_id=trace_id,
    )
    assert len(response.get("traces")) == 2
    for trace in response.get("traces"):
        download_url = trace.get("download_url")
        headers = {"Content-Type": "application/octet-stream"}
        downloaded_profile = writer._s3_pool.request(
            "GET", download_url, headers=headers, timeout=writer._timeout_seconds
        )
        assert downloaded_profile is not None


@pytest.mark.load
def test_transaction_distributed():
    ORG_ID = os.environ.get("WHYLABS_DEFAULT_ORG_ID")
    MODEL_ID = os.environ.get("WHYLABS_DEFAULT_DATASET_ID")
    why.init(force_local=True)
    schema = DatasetSchema()
    csv_url = "https://whylabs-public.s3.us-west-2.amazonaws.com/datasets/tour/current.csv"
    df = pd.read_csv(csv_url)
    pdfs = np.array_split(df, 7)
    writer = WhyLabsWriter(dataset_id=MODEL_ID)
    tids = list()
    try:
        transaction_id = writer.start_transaction()
        for data in pdfs:  # pretend each iteration runs on a different machine
            dist_writer = WhyLabsWriter(dataset_id=MODEL_ID)
            dist_writer.start_transaction(transaction_id)
            trace_id = str(uuid4())
            tids.append(trace_id)
            result = why.log(data, schema=schema, trace_id=trace_id)
            status, id = dist_writer.write(result.profile())
            if not status:
                raise Exception()  # or retry the profile...
        writer.commit_transaction()
    except Exception:
        # start_transaction() and commit_transaction() throw on failure;
        # alternatively, retry the commit.
        logger.exception("Logging transaction failed")

    time.sleep(SLEEP_TIME)  # the platform needs time to become aware of the profiles
    dataset_api = DatasetProfileApi(writer._api_client)
    for trace_id in tids:
        response: ProfileTracesResponse = dataset_api.get_profile_traces(
            org_id=ORG_ID,
            dataset_id=MODEL_ID,
            trace_id=trace_id,
        )
        download_url = response.get("traces")[0]["download_url"]
        headers = {"Content-Type": "application/octet-stream"}
        downloaded_profile = writer._s3_pool.request(
            "GET", download_url, headers=headers, timeout=writer._timeout_seconds
        )
        deserialized_view = DatasetProfileView.deserialize(downloaded_profile.data)
        assert deserialized_view is not None
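
The fixed SLEEP_TIME waits above keep the tests simple; a bounded polling loop over the same get_profile_traces() call is an alternative (a sketch: wait_for_traces is a hypothetical helper, and the attempt count and delay are arbitrary):

def wait_for_traces(dataset_api, org_id, dataset_id, trace_id, attempts=10, delay=5):
    # Poll until the platform reports traces for this profile, or give up.
    for _ in range(attempts):
        response: ProfileTracesResponse = dataset_api.get_profile_traces(
            org_id=org_id,
            dataset_id=dataset_id,
            trace_id=trace_id,
        )
        traces = response.get("traces")
        if traces:
            return traces
        time.sleep(delay)
    return []
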
97 changes: 89 additions & 8 deletions python/whylogs/api/writer/whylabs.py
@@ -112,7 +112,8 @@ def _check_whylabs_condition_count_uncompound() -> bool:
        else:
            logger.info(f"Got response code {response.status_code} but expected 200, so running uncompound")
    except Exception:
-        logger.warning("Error trying to read whylabs config, falling back to defaults for uncompounding")
+        pass

    _WHYLABS_SKIP_CONFIG_READ = True
    return True
@@ -573,6 +574,65 @@ def _write_segmented_reference_result_set(self, file: SegmentedResultSet, **kwargs: Any) -> Tuple[bool, str]:
        else:
            return False, "Failed to upload all segments"

    def _flatten_tags(self, tags: Union[List, Dict]) -> List[SegmentTag]:
        # tags is either a list of {"key": ..., "value": ...} dicts or a nested
        # list of such lists (one per segmented view); recurse on nested lists.
        if type(tags[0]) == list:
            result: List[SegmentTag] = []
            for t in tags:
                result.append(self._flatten_tags(t))
            return result

        return [SegmentTag(t["key"], t["value"]) for t in tags]

    def _write_segmented_result_set_transaction(self, file: SegmentedResultSet, **kwargs: Any) -> Tuple[bool, str]:
        utc_now = datetime.datetime.now(datetime.timezone.utc)

        files = file.get_writables()
        partitions = file.partitions
        if len(partitions) > 1:
            logger.warning(
                "SegmentedResultSet contains more than one partition. Only the first partition will be uploaded."
            )
        partition = partitions[0]
        whylabs_tags = list()
        for view in files:
            view_tags = list()
            dataset_timestamp = view.dataset_timestamp or utc_now
            if view.partition.id != partition.id:
                continue
            _, segment_tags, _ = _generate_segment_tags_metadata(view.segment, view.partition)
            for segment_tag in segment_tags:
                tag_key = segment_tag.key.replace("whylogs.tag.", "")
                tag_value = segment_tag.value
                view_tags.append({"key": tag_key, "value": tag_value})
            whylabs_tags.append(view_tags)
        stamp = dataset_timestamp.timestamp()
        dataset_timestamp_epoch = int(stamp * 1000)

        region = os.getenv("WHYLABS_UPLOAD_REGION", None)
        client: TransactionsApi = self._get_or_create_transaction_client()
        messages: List[str] = list()
        and_status: bool = True
        for view, tags in zip(files, self._flatten_tags(whylabs_tags)):
            with tempfile.NamedTemporaryFile() as tmp_file:
                view.write(file=tmp_file)
                tmp_file.flush()
                tmp_file.seek(0)
                request = TransactionLogRequest(
                    dataset_timestamp=dataset_timestamp_epoch, segment_tags=tags, region=region
                )
                result: AsyncLogResponse = client.log_transaction(self._transaction_id, request, **kwargs)
                logger.info(f"Added profile {result.id} to transaction {self._transaction_id}")
                bool_status, message = self._do_upload(
                    dataset_timestamp=dataset_timestamp_epoch,
                    upload_url=result.upload_url,
                    profile_id=result.id,
                    profile_file=tmp_file,
                )
                and_status = and_status and bool_status
                messages.append(message)

        return and_status, "; ".join(messages)

    def _write_segmented_result_set(self, file: SegmentedResultSet, **kwargs: Any) -> Tuple[bool, str]:
        """Put segmented result set for the specified dataset.
@@ -585,6 +645,9 @@ def _write_segmented_result_set(self, file: SegmentedResultSet, **kwargs: Any) -> Tuple[bool, str]:
        -------
        Tuple[bool, str]
        """
        if self._transaction_id is not None:
            return self._write_segmented_result_set_transaction(file, **kwargs)

        # multi-profile writer
        files = file.get_writables()
        messages: List[str] = list()
@@ -617,39 +680,46 @@ def _get_or_create_transaction_client(self) -> TransactionsApi:
            self._refresh_client()
        return TransactionsApi(self._api_client)

-    def start_transaction(self, **kwargs) -> None:
+    def start_transaction(self, transaction_id: Optional[str] = None, **kwargs) -> str:
        """
        Initiates a transaction -- any profiles subsequently written by calling write()
-        will be uploaded to WhyLabs atomically when commit_transaction() is called. Throws
+        will be uploaded to WhyLabs, but not ingested until commit_transaction() is called. Throws
        on failure.
        """
        if self._transaction_id is not None:
            logger.error("Must end current transaction with commit_transaction() before starting another")
-            return
+            return self._transaction_id

        if kwargs.get("dataset_id") is not None:
            self._dataset_id = kwargs.get("dataset_id")

        if transaction_id is not None:
            self._transaction_id = transaction_id  # type: ignore
            return transaction_id

        client: TransactionsApi = self._get_or_create_transaction_client()
        request = TransactionStartRequest(dataset_id=self._dataset_id)
        result: LogTransactionMetadata = client.start_transaction(request, **kwargs)
        self._transaction_id = result["transaction_id"]
        logger.info(f"Starting transaction {self._transaction_id}, expires {result['expiration_time']}")
        return self._transaction_id  # type: ignore

    def commit_transaction(self, **kwargs) -> None:
        """
-        Atomically upload any profiles written since the previous start_transaction().
+        Ingest any profiles written since the previous start_transaction().
        Throws on failure.
        """
        if self._transaction_id is None:
            logger.error("Must call start_transaction() before commit_transaction()")
            return

-        logger.info(f"Committing transaction {self._transaction_id}")
+        id = self._transaction_id
+        self._transaction_id = None
+        logger.info(f"Committing transaction {id}")
        client = self._get_or_create_transaction_client()
        request = TransactionCommitRequest(verbose=True)
-        client.commit_transaction(self._transaction_id, request, **kwargs)
-        self._transaction_id = None
+        # We abandon the transaction if this throws
+        client.commit_transaction(id, request, **kwargs)

    @deprecated_alias(profile="file")
    def write(self, file: Writable, **kwargs: Any) -> Tuple[bool, str]:
@@ -1079,3 +1149,14 @@ def _get_upload_url(self, dataset_timestamp: int) -> Tuple[str, str]:
            logger.debug(f"Replaced URL with our private domain. New URL: {upload_url}")

        return upload_url, profile_id


class WhyLabsTransaction:
    def __init__(self, writer: WhyLabsWriter):
        self._writer = writer

    def __enter__(self) -> None:
        self._writer.start_transaction()

    def __exit__(self, exc_type, exc_value, exc_tb) -> None:
        self._writer.commit_transaction()

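Putting the writer pieces together, the distributed flow that the new transaction_id parameter enables looks roughly like this (a sketch: the coordinator/worker split is illustrative, "model-0" is a placeholder dataset id, and chunk stands for each worker's share of the data):

# Coordinator: open a transaction and hand its id to the workers.
coordinator = WhyLabsWriter(dataset_id="model-0")
transaction_id = coordinator.start_transaction()

# Worker (possibly on another machine): join the open transaction by id.
worker = WhyLabsWriter(dataset_id="model-0")
worker.start_transaction(transaction_id)
worker.write(why.log(chunk).profile())  # chunk: this worker's DataFrame slice

# Coordinator: ingest all uploaded profiles at once; throws on failure.
coordinator.commit_transaction()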