Skip to content

Commit

Permalink
Add ability to use timedelta and strings to start_time and `end_time…
Browse files Browse the repository at this point in the history
…`. (#8348)

Add the ability to use timedelta objects and strings for start_time and end_time.

Also moved ttl_to_seconds function out of cache_utils to runtime_utils and renamed it to duration_to_seconds, since now it is also used outside of the cache module.
  • Loading branch information
kajarenc committed Mar 26, 2024
1 parent 6f4f0ee commit 8121928
Show file tree
Hide file tree
Showing 9 changed files with 160 additions and 75 deletions.
56 changes: 52 additions & 4 deletions lib/streamlit/elements/media.py
Expand Up @@ -16,6 +16,7 @@

import io
import re
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Final, Union, cast

Expand All @@ -29,6 +30,7 @@
from streamlit.proto.Video_pb2 import Video as VideoProto
from streamlit.runtime import caching
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.runtime_util import duration_to_seconds

if TYPE_CHECKING:
from typing import Any
Expand All @@ -45,17 +47,27 @@
str, Path, bytes, io.BytesIO, Dict[str, Union[str, Path, bytes, io.BytesIO]], None
]

MediaTime: TypeAlias = Union[int, float, timedelta, str]

TIMEDELTA_PARSE_ERROR_MESSAGE: Final = (
"Failed to convert '{param_name}' to a timedelta. "
"Please use a string in a format supported by "
"[Pandas Timedelta constructor]"
"(https://pandas.pydata.org/docs/reference/api/pandas.Timedelta.html), "
'e.g. `"10s"`, `"15 seconds"`, or `"1h23s"`. Got: {param_value}'
)


class MediaMixin:
@gather_metrics("audio")
def audio(
self,
data: MediaData,
format: str = "audio/wav",
start_time: int = 0,
start_time: MediaTime = 0,
*,
sample_rate: int | None = None,
end_time: int | None = None,
end_time: MediaTime | None = None,
loop: bool = False,
) -> DeltaGenerator:
"""Display an audio player.
Expand Down Expand Up @@ -111,6 +123,8 @@ def audio(
height: 865px
"""
start_time, end_time = _parse_start_time_end_time(start_time, end_time)

audio_proto = AudioProto()
coordinates = self.dg._get_delta_path_str()

Expand Down Expand Up @@ -143,10 +157,10 @@ def video(
self,
data: MediaData,
format: str = "video/mp4",
start_time: int = 0,
start_time: MediaTime = 0,
*, # keyword-only arguments:
subtitles: SubtitleData = None,
end_time: int | None = None,
end_time: MediaTime | None = None,
loop: bool = False,
) -> DeltaGenerator:
"""Display a video player.
Expand Down Expand Up @@ -246,6 +260,9 @@ def video(
for more information.
"""

start_time, end_time = _parse_start_time_end_time(start_time, end_time)

video_proto = VideoProto()
coordinates = self.dg._get_delta_path_str()
marshall_video(
Expand Down Expand Up @@ -461,6 +478,37 @@ def marshall_video(
) from original_err


def _parse_start_time_end_time(
    start_time: MediaTime, end_time: MediaTime | None
) -> tuple[int, int | None]:
    """Convert ``start_time`` and ``end_time`` to whole numbers of seconds.

    Each value is passed through ``duration_to_seconds`` and the result is
    truncated with ``int()``.

    Parameters
    ----------
    start_time : MediaTime
        Where playback starts. A ``None`` result from ``duration_to_seconds``
        is rejected.
    end_time : MediaTime or None
        Where playback stops. ``None`` is passed through unchanged.

    Returns
    -------
    tuple[int, int | None]
        ``(start_time, end_time)`` in seconds.

    Raises
    ------
    StreamlitAPIException
        If either value cannot be converted. The message is built from
        ``TIMEDELTA_PARSE_ERROR_MESSAGE`` and always shows the value the
        caller passed in.
    """
    # Results go into separate locals so that the except blocks always report
    # the caller's original input rather than a partially converted value.
    try:
        start_seconds = duration_to_seconds(start_time, coerce_none_to_inf=False)
        if start_seconds is None:
            raise ValueError("start_time must not be None")
        parsed_start_time = int(start_seconds)
    # int() raises ValueError for NaN and OverflowError for +/-inf.
    except (StreamlitAPIException, ValueError, OverflowError):
        error_msg = TIMEDELTA_PARSE_ERROR_MESSAGE.format(
            param_name="start_time", param_value=start_time
        )
        raise StreamlitAPIException(error_msg) from None

    try:
        # TODO[kajarenc]: Replace `duration_to_seconds` with `time_to_seconds`
        # when PR #8343 is merged.
        end_seconds = duration_to_seconds(end_time, coerce_none_to_inf=False)
        parsed_end_time = None if end_seconds is None else int(end_seconds)
    # Same exception set as for start_time, so both parameters fail uniformly.
    except (StreamlitAPIException, ValueError, OverflowError):
        error_msg = TIMEDELTA_PARSE_ERROR_MESSAGE.format(
            param_name="end_time", param_value=end_time
        )
        raise StreamlitAPIException(error_msg) from None

    return parsed_start_time, parsed_end_time


def _validate_and_normalize(data: npt.NDArray[Any]) -> tuple[bytes, int]:
"""Validates and normalizes numpy array data.
We validate numpy array shape (should be 1d or 2d)
Expand Down
6 changes: 3 additions & 3 deletions lib/streamlit/runtime/caching/cache_data_api.py
Expand Up @@ -35,7 +35,6 @@
Cache,
CachedFuncInfo,
make_cached_func_wrapper,
ttl_to_seconds,
)
from streamlit.runtime.caching.cached_message_replay import (
CachedMessageReplayContext,
Expand All @@ -59,6 +58,7 @@
MemoryCacheStorageManager,
)
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.runtime_util import duration_to_seconds
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx
from streamlit.runtime.stats import CacheStat, CacheStatsProvider, group_stats

Expand Down Expand Up @@ -154,7 +154,7 @@ def get_cache(
If it doesn't exist, create a new one with the given params.
"""

ttl_seconds = ttl_to_seconds(ttl, coerce_none_to_inf=False)
ttl_seconds = duration_to_seconds(ttl, coerce_none_to_inf=False)

# Get the existing cache, if it exists, and validate that its params
# haven't changed.
Expand Down Expand Up @@ -254,7 +254,7 @@ def validate_cache_params(
CacheStorageContext.
"""

ttl_seconds = ttl_to_seconds(ttl, coerce_none_to_inf=False)
ttl_seconds = duration_to_seconds(ttl, coerce_none_to_inf=False)

cache_context = self.create_cache_storage_context(
function_key="DUMMY_KEY",
Expand Down
11 changes: 0 additions & 11 deletions lib/streamlit/runtime/caching/cache_errors.py
Expand Up @@ -176,14 +176,3 @@ class UnevaluatedDataFrameError(StreamlitAPIException):
"""Used to display a message about uncollected dataframe being used"""

pass


class BadTTLStringError(StreamlitAPIException):
"""Raised when a bad ttl= argument string is passed."""

def __init__(self, ttl: str):
MarkdownFormattedException.__init__(
self,
"TTL string doesn't look right. It should be formatted as"
f"`'1d2h34m'` or `2 days`, for example. Got: {ttl}",
)
4 changes: 2 additions & 2 deletions lib/streamlit/runtime/caching/cache_resource_api.py
Expand Up @@ -35,7 +35,6 @@
Cache,
CachedFuncInfo,
make_cached_func_wrapper,
ttl_to_seconds,
)
from streamlit.runtime.caching.cached_message_replay import (
CachedMessageReplayContext,
Expand All @@ -46,6 +45,7 @@
)
from streamlit.runtime.caching.hashing import HashFuncsDict
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.runtime_util import duration_to_seconds
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx
from streamlit.runtime.stats import CacheStat, CacheStatsProvider, group_stats

Expand Down Expand Up @@ -89,7 +89,7 @@ def get_cache(
if max_entries is None:
max_entries = math.inf

ttl_seconds = ttl_to_seconds(ttl)
ttl_seconds = duration_to_seconds(ttl)

# Get the existing cache, if it exists, and validate that its params
# haven't changed.
Expand Down
44 changes: 1 addition & 43 deletions lib/streamlit/runtime/caching/cache_utils.py
Expand Up @@ -19,20 +19,17 @@
import functools
import hashlib
import inspect
import math
import threading
import time
import types
from abc import abstractmethod
from collections import defaultdict
from datetime import timedelta
from typing import Any, Callable, Final, Literal, overload
from typing import Any, Callable, Final

from streamlit import type_util
from streamlit.elements.spinner import spinner
from streamlit.logger import get_logger
from streamlit.runtime.caching.cache_errors import (
BadTTLStringError,
CacheError,
CacheKeyNotFoundError,
UnevaluatedDataFrameError,
Expand All @@ -58,45 +55,6 @@
TTLCACHE_TIMER = time.monotonic


@overload
def ttl_to_seconds(
ttl: float | timedelta | str | None, *, coerce_none_to_inf: Literal[False]
) -> float | None:
...


@overload
def ttl_to_seconds(ttl: float | timedelta | str | None) -> float:
...


def ttl_to_seconds(
ttl: float | timedelta | str | None, *, coerce_none_to_inf: bool = True
) -> float | None:
"""
Convert a ttl value to a float representing "number of seconds".
"""
if coerce_none_to_inf and ttl is None:
return math.inf
if isinstance(ttl, timedelta):
return ttl.total_seconds()
if isinstance(ttl, str):
import numpy as np
import pandas as pd

try:
out: float = pd.Timedelta(ttl).total_seconds()
except ValueError as ex:
raise BadTTLStringError(ttl) from ex

if np.isnan(out):
raise BadTTLStringError(ttl)

return out

return ttl


# We show a special "UnevaluatedDataFrame" warning for cached funcs
# that attempt to return one of these unserializable types:
UNEVALUATED_DATAFRAME_TYPES = (
Expand Down
56 changes: 54 additions & 2 deletions lib/streamlit/runtime/runtime_util.py
Expand Up @@ -16,10 +16,12 @@

from __future__ import annotations

from typing import Any
import math
from datetime import timedelta
from typing import Any, Literal, overload

from streamlit import config
from streamlit.errors import MarkdownFormattedException
from streamlit.errors import MarkdownFormattedException, StreamlitAPIException
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime.forward_msg_cache import populate_hash_if_needed

Expand Down Expand Up @@ -54,6 +56,17 @@ def _get_message(self, failed_msg_str: Any) -> str:
)


class BadDurationStringError(StreamlitAPIException):
    """Raised when a bad duration argument string is passed.

    ``duration_to_seconds`` raises this when ``pd.Timedelta`` rejects the
    string or when the parsed value is NaN.
    """

    def __init__(self, duration: str):
        # NOTE(review): the MarkdownFormattedException initializer is called
        # directly instead of via super(); presumably StreamlitAPIException
        # derives from it - confirm.
        MarkdownFormattedException.__init__(
            self,
            # The trailing space keeps "as" from running into the example.
            "TTL string doesn't look right. It should be formatted as "
            f"`'1d2h34m'` or `2 days`, for example. Got: {duration}",
        )


def is_cacheable_msg(msg: ForwardMsg) -> bool:
"""True if the given message qualifies for caching."""
if msg.WhichOneof("type") in {"ref_hash", "initialize"}:
Expand All @@ -62,6 +75,45 @@ def is_cacheable_msg(msg: ForwardMsg) -> bool:
return msg.ByteSize() >= int(config.get_option("global.minCachedMessageSize"))


@overload
def duration_to_seconds(
    ttl: float | timedelta | str | None, *, coerce_none_to_inf: Literal[False]
) -> float | None:
    ...


@overload
def duration_to_seconds(ttl: float | timedelta | str | None) -> float:
    ...


def duration_to_seconds(
    ttl: float | timedelta | str | None, *, coerce_none_to_inf: bool = True
) -> float | None:
    """Convert a duration value to a number of seconds.

    Parameters
    ----------
    ttl : float, timedelta, str or None
        The duration. Numbers are returned unchanged, ``timedelta`` objects
        are converted with ``total_seconds()``, and strings are parsed with
        ``pandas.Timedelta``.
    coerce_none_to_inf : bool
        If True (default), ``None`` is converted to ``math.inf``; otherwise
        ``None`` is returned as is.

    Returns
    -------
    float or None
        The duration in seconds, or ``None`` when ``ttl`` is ``None`` and
        ``coerce_none_to_inf`` is False.

    Raises
    ------
    BadDurationStringError
        If ``ttl`` is a string that ``pandas.Timedelta`` rejects, or one that
        parses to a NaN number of seconds.
    """
    if ttl is None:
        return math.inf if coerce_none_to_inf else None

    if isinstance(ttl, timedelta):
        return ttl.total_seconds()

    if not isinstance(ttl, str):
        return ttl

    # Imported lazily so pandas is only loaded when a string is actually parsed.
    import pandas as pd

    try:
        seconds: float = pd.Timedelta(ttl).total_seconds()
    except ValueError as ex:
        raise BadDurationStringError(ttl) from ex

    # total_seconds() yields a plain float, so math.isnan is sufficient and
    # no numpy import is needed for this check.
    if math.isnan(seconds):
        raise BadDurationStringError(ttl)

    return seconds


def serialize_forward_msg(msg: ForwardMsg) -> bytes:
"""Serialize a ForwardMsg to send to a client.
Expand Down
41 changes: 40 additions & 1 deletion lib/tests/streamlit/elements/audio_test.py
Expand Up @@ -23,7 +23,10 @@
from scipy.io import wavfile

import streamlit as st
from streamlit.elements.media import _maybe_convert_to_wav_bytes
from streamlit.elements.media import (
_maybe_convert_to_wav_bytes,
_parse_start_time_end_time,
)
from streamlit.errors import StreamlitAPIException
from streamlit.proto.Alert_pb2 import Alert as AlertProto
from streamlit.runtime.media_file_storage import MediaFileStorageError
Expand Down Expand Up @@ -275,3 +278,39 @@ def test_st_audio_options(self):
self.assertEqual(el.audio.loop, True)
self.assertTrue(el.audio.url.startswith(MEDIA_ENDPOINT))
self.assertTrue(_calculate_file_id(fake_audio_data, "audio/mp3"), el.audio.url)

@parameterized.expand(
    [
        ("1s", None, (1, None)),
        ("1m", None, (60, None)),
        ("1m2s", None, (62, None)),
        (0, "1m", (0, 60)),
        ("1h2m3s", None, (3723, None)),
        ("1m2s", "1m10s", (62, 70)),
        ("10 seconds", "15 seconds", (10, 15)),
        ("3 minutes 10 seconds", "3 minutes 20 seconds", (190, 200)),
    ]
)
def test_parse_start_time_end_time_success(
    self, input_start_time, input_end_time, expected_value
):
    """Valid start/end inputs are parsed into the expected seconds tuple."""
    parsed = _parse_start_time_end_time(input_start_time, input_end_time)
    self.assertEqual(parsed, expected_value)

@parameterized.expand(
    [
        ("INVALID_VALUE", None, "Failed to convert 'start_time' to a timedelta"),
        (5, "INVALID_VALUE", "Failed to convert 'end_time' to a timedelta"),
    ]
)
def test_parse_start_time_end_time_fail(self, start_time, end_time, exception_text):
    """Unparseable inputs raise StreamlitAPIException naming the bad parameter."""
    with self.assertRaises(StreamlitAPIException) as raised:
        _parse_start_time_end_time(start_time, end_time)

    # The message must identify the parameter and echo the offending value.
    message = str(raised.exception)
    self.assertIn(exception_text, message)
    self.assertIn("INVALID_VALUE", message)
2 changes: 1 addition & 1 deletion lib/tests/streamlit/elements/help_test.py
Expand Up @@ -109,7 +109,7 @@ def test_deltagenerator_func(self):
self.assertEqual("st.audio", ds.name)
self.assertEqual("method", ds.type)

signature = "(data: 'MediaData', format: 'str' = 'audio/wav', start_time: 'int' = 0, *, sample_rate: 'int | None' = None, end_time: 'int | None' = None, loop: 'bool' = False) -> 'DeltaGenerator'"
signature = "(data: 'MediaData', format: 'str' = 'audio/wav', start_time: 'MediaTime' = 0, *, sample_rate: 'int | None' = None, end_time: 'MediaTime | None' = None, loop: 'bool' = False) -> 'DeltaGenerator'"

self.assertEqual(
f"streamlit.delta_generator.MediaMixin.audio{signature}", ds.value
Expand Down

0 comments on commit 8121928

Please sign in to comment.