Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate attributes in update methods #134

Merged
merged 8 commits into from
Nov 20, 2023
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 50 additions & 29 deletions src/segments/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,16 +474,16 @@ def add_dataset(
"categories": [{"id": 1, "name": "object"}],
}

if type(task_attributes) is dict:
if isinstance(task_attributes, TaskAttributes):
task_attributes = task_attributes.model_dump()
else:
try:
TaskAttributes.model_validate(task_attributes)
except pydantic.ValidationError as e:
logger.error(
"Did you use the right task attributes? Please refer to the online documentation: https://docs.segments.ai/reference/categories-and-task-attributes#object-attribute-format.",
)
raise ValidationError(message=str(e), cause=e)
elif type(task_attributes) is TaskAttributes:
task_attributes = task_attributes.model_dump()

payload: Dict[str, Any] = {
"name": name,
Expand Down Expand Up @@ -583,11 +583,18 @@ def update_dataset(
payload["task_type"] = task_type

if task_attributes is not None:
payload["task_attributes"] = (
task_attributes.model_dump()
if type(task_attributes) is TaskAttributes
else task_attributes
)
if isinstance(task_attributes, TaskAttributes):
task_attributes = task_attributes.model_dump()
else:
try:
TaskAttributes.model_validate(task_attributes)
except pydantic.ValidationError as e:
logger.error(
"Did you use the right task attributes? Please refer to the online documentation: https://docs.segments.ai/reference/categories-and-task-attributes#object-attribute-format.",
)
raise ValidationError(message=str(e), cause=e)

payload["task_attributes"] = task_attributes

if category is not None:
payload["category"] = category
Expand Down Expand Up @@ -1023,16 +1030,16 @@ def add_sample(
:exc:`~segments.exceptions.TimeoutError`: If the request times out.
"""

if type(attributes) is dict:
if isinstance(attributes, get_args(SampleAttributes)):
attributes = attributes.model_dump()
else:
try:
TypeAdapter(SampleAttributes).validate_python(attributes)
except pydantic.ValidationError as e:
logger.error(
"Did you use the right sample attributes? Please refer to the online documentation: https://docs.segments.ai/reference/sample-and-label-types/sample-types.",
)
raise ValidationError(message=str(e), cause=e)
elif type(attributes) in get_args(SampleAttributes):
attributes = attributes.model_dump()

payload: Dict[str, Any] = {
"name": name,
Expand Down Expand Up @@ -1083,9 +1090,10 @@ def add_samples(
:exc:`~segments.exceptions.TimeoutError`: If the request times out.
"""

# Check the input
for sample in samples:
if type(sample) is dict:
if isinstance(sample, Sample):
sample = sample.model_dump()
else:
if "name" not in sample or "attributes" not in sample:
raise KeyError(
f"Please add a name and attributes to your sample: {sample}"
Expand All @@ -1098,8 +1106,6 @@ def add_samples(
"Did you use the right sample attributes? Please refer to the online documentation: https://docs.segments.ai/reference/sample-and-label-types/sample-types.",
)
raise ValidationError(message=str(e), cause=e)
elif type(sample) is Sample:
sample = sample.model_dump()

payload = samples

Expand Down Expand Up @@ -1148,7 +1154,7 @@ def update_sample(

Raises:
:exc:`~segments.exceptions.APILimitError`: If the API limit is exceeded.
:exc:`~segments.exceptions.ValidationError`: If validation of the samples fails.
:exc:`~segments.exceptions.ValidationError`: If validation of the sample fails.
:exc:`~segments.exceptions.NotFoundError`: If the sample is not found.
:exc:`~segments.exceptions.NetworkError`: If the request is not valid or if the server experienced an error.
:exc:`~segments.exceptions.TimeoutError`: If the request times out.
Expand All @@ -1160,11 +1166,18 @@ def update_sample(
payload["name"] = name

if attributes is not None:
payload["attributes"] = (
attributes.model_dump()
if type(attributes) in get_args(SampleAttributes)
else attributes
)
if isinstance(attributes, get_args(SampleAttributes)):
attributes = attributes.model_dump()
else:
try:
TypeAdapter(SampleAttributes).validate_python(attributes)
except pydantic.ValidationError as e:
logger.error(
"Did you use the right sample attributes? Please refer to the online documentation: https://docs.segments.ai/reference/sample-and-label-types/sample-types.",
)
raise ValidationError(message=str(e), cause=e)

payload["attributes"] = attributes

if metadata is not None:
payload["metadata"] = metadata
Expand Down Expand Up @@ -1287,16 +1300,16 @@ def add_label(
:exc:`~segments.exceptions.TimeoutError`: If the request times out.
"""

if type(attributes) is dict:
if isinstance(attributes, get_args(LabelAttributes)):
attributes = attributes.model_dump()
else:
try:
TypeAdapter(LabelAttributes).validate_python(attributes)
except pydantic.ValidationError as e:
logger.error(
"Did you use the right label attributes? Please refer to the online documentation: https://docs.segments.ai/reference/sample-and-label-types/label-types.",
)
raise ValidationError(message=str(e), cause=e)
elif type(attributes) in get_args(LabelAttributes):
attributes = attributes.model_dump()

payload: Dict[str, Any] = {
"label_status": label_status,
Expand Down Expand Up @@ -1357,11 +1370,18 @@ def update_label(
payload: Dict[str, Any] = {}

if attributes is not None:
payload["attributes"] = (
attributes.model_dump()
if type(attributes) in get_args(LabelAttributes)
else attributes
)
if isinstance(attributes, get_args(LabelAttributes)):
attributes = attributes.model_dump()
else:
try:
TypeAdapter(LabelAttributes).validate_python(attributes)
except pydantic.ValidationError as e:
logger.error(
"Did you use the right label attributes? Please refer to the online documentation: https://docs.segments.ai/reference/sample-and-label-types/label-types.",
)
raise ValidationError(message=str(e), cause=e)

payload["attributes"] = attributes

if label_status is not None:
payload["label_status"] = label_status
Expand Down Expand Up @@ -1477,6 +1497,7 @@ def add_labelset(
:exc:`~segments.exceptions.NetworkError`: If the request is not valid or if the server experienced an error.
:exc:`~segments.exceptions.TimeoutError`: If the request times out.
"""

payload = {
"name": name,
"description": description,
Expand Down