Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to use timedelta and stings to start_time and end_time. #8348

Merged
merged 20 commits into from Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
54 changes: 50 additions & 4 deletions lib/streamlit/elements/media.py
Expand Up @@ -16,6 +16,7 @@

import io
import re
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Final, Union, cast

Expand All @@ -29,6 +30,7 @@
from streamlit.proto.Video_pb2 import Video as VideoProto
from streamlit.runtime import caching
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.runtime_util import duration_to_seconds

if TYPE_CHECKING:
from typing import Any
Expand All @@ -45,17 +47,27 @@
str, Path, bytes, io.BytesIO, Dict[str, Union[str, Path, bytes, io.BytesIO]], None
]

MediaTime: TypeAlias = Union[int, float, timedelta, str]

TIMEDELTA_PARSE_ERROR_MESSAGE: Final = (
"Failed to convert '{param_name}' to a timedelta. "
"Please use a string in a format supported by "
"[Pandas Timedelta constructor]"
"(https://pandas.pydata.org/docs/reference/api/pandas.Timedelta.html), "
'e.g. `"10s"`, `"15 seconds"`, or `"1h23s"`. Got: {param_value}'
)


class MediaMixin:
@gather_metrics("audio")
def audio(
self,
data: MediaData,
format: str = "audio/wav",
start_time: int = 0,
start_time: MediaTime = 0,
*,
sample_rate: int | None = None,
end_time: int | None = None,
end_time: MediaTime | None = None,
loop: bool = False,
) -> DeltaGenerator:
"""Display an audio player.
Expand Down Expand Up @@ -111,6 +123,8 @@ def audio(
height: 865px

"""
start_time, end_time = _parse_start_time_end_time(start_time, end_time)

audio_proto = AudioProto()
coordinates = self.dg._get_delta_path_str()

Expand Down Expand Up @@ -143,10 +157,10 @@ def video(
self,
data: MediaData,
format: str = "video/mp4",
start_time: int = 0,
start_time: MediaTime = 0,
*, # keyword-only arguments:
subtitles: SubtitleData = None,
end_time: int | None = None,
end_time: MediaTime | None = None,
loop: bool = False,
) -> DeltaGenerator:
"""Display a video player.
Expand Down Expand Up @@ -246,6 +260,9 @@ def video(
for more information.

"""

start_time, end_time = _parse_start_time_end_time(start_time, end_time)

video_proto = VideoProto()
coordinates = self.dg._get_delta_path_str()
marshall_video(
Expand Down Expand Up @@ -461,6 +478,35 @@ def marshall_video(
) from original_err


def _parse_start_time_end_time(
start_time: MediaTime, end_time: MediaTime | None
) -> tuple[int, int | None]:
"""Parse start_time and end_time and return them as int."""

try:
maybe_start_time = duration_to_seconds(start_time, coerce_none_to_inf=False)
if maybe_start_time is None:
raise ValueError
start_time = int(maybe_start_time)
except (StreamlitAPIException, ValueError):
error_msg = TIMEDELTA_PARSE_ERROR_MESSAGE.format(
param_name="start_time", param_value=start_time
)
raise StreamlitAPIException(error_msg) from None

try:
end_time = duration_to_seconds(end_time, coerce_none_to_inf=False)
if end_time is not None:
end_time = int(end_time)
except StreamlitAPIException:
error_msg = TIMEDELTA_PARSE_ERROR_MESSAGE.format(
param_name="end_time", param_value=end_time
)
raise StreamlitAPIException(error_msg) from None

return start_time, end_time


def _validate_and_normalize(data: npt.NDArray[Any]) -> tuple[bytes, int]:
"""Validates and normalizes numpy array data.
We validate numpy array shape (should be 1d or 2d)
Expand Down
6 changes: 3 additions & 3 deletions lib/streamlit/runtime/caching/cache_data_api.py
Expand Up @@ -35,7 +35,6 @@
Cache,
CachedFuncInfo,
make_cached_func_wrapper,
ttl_to_seconds,
)
from streamlit.runtime.caching.cached_message_replay import (
CachedMessageReplayContext,
Expand All @@ -59,6 +58,7 @@
MemoryCacheStorageManager,
)
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.runtime_util import duration_to_seconds
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx
from streamlit.runtime.stats import CacheStat, CacheStatsProvider, group_stats

Expand Down Expand Up @@ -154,7 +154,7 @@ def get_cache(
If it doesn't exist, create a new one with the given params.
"""

ttl_seconds = ttl_to_seconds(ttl, coerce_none_to_inf=False)
ttl_seconds = duration_to_seconds(ttl, coerce_none_to_inf=False)

# Get the existing cache, if it exists, and validate that its params
# haven't changed.
Expand Down Expand Up @@ -254,7 +254,7 @@ def validate_cache_params(
CacheStorageContext.
"""

ttl_seconds = ttl_to_seconds(ttl, coerce_none_to_inf=False)
ttl_seconds = duration_to_seconds(ttl, coerce_none_to_inf=False)

cache_context = self.create_cache_storage_context(
function_key="DUMMY_KEY",
Expand Down
4 changes: 2 additions & 2 deletions lib/streamlit/runtime/caching/cache_resource_api.py
Expand Up @@ -35,7 +35,6 @@
Cache,
CachedFuncInfo,
make_cached_func_wrapper,
ttl_to_seconds,
)
from streamlit.runtime.caching.cached_message_replay import (
CachedMessageReplayContext,
Expand All @@ -46,6 +45,7 @@
)
from streamlit.runtime.caching.hashing import HashFuncsDict
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.runtime_util import duration_to_seconds
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx
from streamlit.runtime.stats import CacheStat, CacheStatsProvider, group_stats

Expand Down Expand Up @@ -89,7 +89,7 @@ def get_cache(
if max_entries is None:
max_entries = math.inf

ttl_seconds = ttl_to_seconds(ttl)
ttl_seconds = duration_to_seconds(ttl)

# Get the existing cache, if it exists, and validate that its params
# haven't changed.
Expand Down
44 changes: 1 addition & 43 deletions lib/streamlit/runtime/caching/cache_utils.py
Expand Up @@ -19,20 +19,17 @@
import functools
import hashlib
import inspect
import math
import threading
import time
import types
from abc import abstractmethod
from collections import defaultdict
from datetime import timedelta
from typing import Any, Callable, Final, Literal, overload
from typing import Any, Callable, Final

from streamlit import type_util
from streamlit.elements.spinner import spinner
from streamlit.logger import get_logger
from streamlit.runtime.caching.cache_errors import (
BadTTLStringError,
CacheError,
CacheKeyNotFoundError,
UnevaluatedDataFrameError,
Expand All @@ -58,45 +55,6 @@
TTLCACHE_TIMER = time.monotonic


@overload
def ttl_to_seconds(
ttl: float | timedelta | str | None, *, coerce_none_to_inf: Literal[False]
) -> float | None:
...


@overload
def ttl_to_seconds(ttl: float | timedelta | str | None) -> float:
...


def ttl_to_seconds(
ttl: float | timedelta | str | None, *, coerce_none_to_inf: bool = True
) -> float | None:
"""
Convert a ttl value to a float representing "number of seconds".
"""
if coerce_none_to_inf and ttl is None:
return math.inf
if isinstance(ttl, timedelta):
return ttl.total_seconds()
if isinstance(ttl, str):
import numpy as np
import pandas as pd

try:
out: float = pd.Timedelta(ttl).total_seconds()
except ValueError as ex:
raise BadTTLStringError(ttl) from ex

if np.isnan(out):
raise BadTTLStringError(ttl)

return out

return ttl


# We show a special "UnevaluatedDataFrame" warning for cached funcs
# that attempt to return one of these unserializable types:
UNEVALUATED_DATAFRAME_TYPES = (
Expand Down
44 changes: 43 additions & 1 deletion lib/streamlit/runtime/runtime_util.py
Expand Up @@ -16,11 +16,14 @@

from __future__ import annotations

from typing import Any
import math
from datetime import timedelta
from typing import Any, Literal, overload

from streamlit import config
from streamlit.errors import MarkdownFormattedException
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime.caching.cache_errors import BadTTLStringError
from streamlit.runtime.forward_msg_cache import populate_hash_if_needed


Expand Down Expand Up @@ -62,6 +65,45 @@
return msg.ByteSize() >= int(config.get_option("global.minCachedMessageSize"))


@overload
def duration_to_seconds(
ttl: float | timedelta | str | None, *, coerce_none_to_inf: Literal[False]
) -> float | None:
...
Fixed Show fixed Hide fixed
Dismissed Show dismissed Hide dismissed


@overload
def duration_to_seconds(ttl: float | timedelta | str | None) -> float:
...
Fixed Show fixed Hide fixed
Dismissed Show dismissed Hide dismissed


def duration_to_seconds(
ttl: float | timedelta | str | None, *, coerce_none_to_inf: bool = True
) -> float | None:
"""
Convert a ttl value to a float representing "number of seconds".
Comment on lines +91 to +94
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: should we rename ttl here to duration to better match the new function name?

Suggested change
ttl: float | timedelta | str | None, *, coerce_none_to_inf: bool = True
) -> float | None:
"""
Convert a ttl value to a float representing "number of seconds".
duration: float | timedelta | str | None, *, coerce_none_to_inf: bool = True
) -> float | None:
"""
Convert a duration value to a float representing "number of seconds".

"""
if coerce_none_to_inf and ttl is None:
return math.inf
if isinstance(ttl, timedelta):
return ttl.total_seconds()
if isinstance(ttl, str):
import numpy as np
import pandas as pd

try:
out: float = pd.Timedelta(ttl).total_seconds()
except ValueError as ex:
raise BadTTLStringError(ttl) from ex

if np.isnan(out):
raise BadTTLStringError(ttl)

return out

return ttl


def serialize_forward_msg(msg: ForwardMsg) -> bytes:
"""Serialize a ForwardMsg to send to a client.

Expand Down
40 changes: 39 additions & 1 deletion lib/tests/streamlit/elements/audio_test.py
Expand Up @@ -23,7 +23,10 @@
from scipy.io import wavfile

import streamlit as st
from streamlit.elements.media import _maybe_convert_to_wav_bytes
from streamlit.elements.media import (
_maybe_convert_to_wav_bytes,
_parse_start_time_end_time,
)
from streamlit.errors import StreamlitAPIException
from streamlit.proto.Alert_pb2 import Alert as AlertProto
from streamlit.runtime.media_file_storage import MediaFileStorageError
Expand Down Expand Up @@ -275,3 +278,38 @@ def test_st_audio_options(self):
self.assertEqual(el.audio.loop, True)
self.assertTrue(el.audio.url.startswith(MEDIA_ENDPOINT))
self.assertTrue(_calculate_file_id(fake_audio_data, "audio/mp3"), el.audio.url)

@parameterized.expand(
[
("1s", None, (1, None)),
("1m", None, (60, None)),
("1m2s", None, (62, None)),
(0, "1m", (0, 60)),
("1h2m3s", None, (3723, None)),
("10 seconds", "15 seconds", (10, 15)),
("3 minutes 10 seconds", "3 minutes 20 seconds", (190, 200)),
]
)
def test_parse_start_time_end_time_success(
self, input_start_time, input_end_time, expected_value
):
"""Test that _parse_start_time_end_time works correctly."""
self.assertEqual(
_parse_start_time_end_time(input_start_time, input_end_time),
expected_value,
)

@parameterized.expand(
[
("INVALID_VALUE", None, "Failed to convert 'start_time' to a timedelta"),
(5, "INVALID_VALUE", "Failed to convert 'end_time' to a timedelta"),
]
)
def test_time_delta_to_seconds_success(self, start_time, end_time, exception_text):
"""Test that _timedelta_to_seconds works with correct exception text."""

with self.assertRaises(StreamlitAPIException) as e:
_parse_start_time_end_time(start_time, end_time)

self.assertIn(exception_text, str(e.exception))
self.assertIn("INVALID_VALUE", str(e.exception))
2 changes: 1 addition & 1 deletion lib/tests/streamlit/elements/help_test.py
Expand Up @@ -109,7 +109,7 @@ def test_deltagenerator_func(self):
self.assertEqual("st.audio", ds.name)
self.assertEqual("method", ds.type)

signature = "(data: 'MediaData', format: 'str' = 'audio/wav', start_time: 'int' = 0, *, sample_rate: 'int | None' = None, end_time: 'int | None' = None, loop: 'bool' = False) -> 'DeltaGenerator'"
signature = "(data: 'MediaData', format: 'str' = 'audio/wav', start_time: 'MediaTime' = 0, *, sample_rate: 'int | None' = None, end_time: 'MediaTime | None' = None, loop: 'bool' = False) -> 'DeltaGenerator'"

self.assertEqual(
f"streamlit.delta_generator.MediaMixin.audio{signature}", ds.value
Expand Down
10 changes: 5 additions & 5 deletions lib/tests/streamlit/runtime/caching/cache_utils_test.py
Expand Up @@ -20,7 +20,7 @@
from parameterized import parameterized

from streamlit.runtime.caching.cache_errors import BadTTLStringError
from streamlit.runtime.caching.cache_utils import ttl_to_seconds
from streamlit.runtime.runtime_util import duration_to_seconds

NORMAL_PARAMS = [
("float", 3.5, 3.5),
Expand All @@ -45,7 +45,7 @@ class CacheUtilsTest(unittest.TestCase):
)
def test_ttl_to_seconds_coerced(self, _, input_value: Any, expected_seconds: float):
"""Test the various types of input that ttl_to_seconds accepts."""
self.assertEqual(expected_seconds, ttl_to_seconds(input_value))
self.assertEqual(expected_seconds, duration_to_seconds(input_value))

@parameterized.expand(
[
Expand All @@ -58,13 +58,13 @@ def test_ttl_to_seconds_not_coerced(
):
"""Test the various types of input that ttl_to_seconds accepts."""
self.assertEqual(
expected_seconds, ttl_to_seconds(input_value, coerce_none_to_inf=False)
expected_seconds, duration_to_seconds(input_value, coerce_none_to_inf=False)
)

def test_ttl_str_exception(self):
"""Test that a badly-formatted TTL string raises an exception."""
with self.assertRaises(BadTTLStringError):
ttl_to_seconds("")
duration_to_seconds("")

with self.assertRaises(BadTTLStringError):
ttl_to_seconds("1 flecond")
duration_to_seconds("1 flecond")