Skip to content

Commit

Permalink
add datetime hashing for st.cache_data and st.cache_resource (streaml…
Browse files Browse the repository at this point in the history
…it#6812)

Previously, datetime objects were hashed by calling __reduce__, but that failed when the datetime object was aware (i.e., contained timezone information).

This change adds special hashing handling for datetime objects by converting them to their isoformat() string (which contains information about the timezone offset).

Please note that with this approach we lose information about the exact time zone when hashing, but information about the UTC offset is preserved. So, for example, identical date-times in different time zones that share the same timestamp and offset will be hashed to the same value, which should be fine in the vast majority of cases.
  • Loading branch information
kajarenc authored and Your Name committed Mar 22, 2024
1 parent 819367c commit 010d7fd
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 1 deletion.
4 changes: 4 additions & 0 deletions lib/streamlit/runtime/caching/hashing.py
Expand Up @@ -15,6 +15,7 @@
"""Hashing for st.cache_data and st.cache_resource."""
import collections
import dataclasses
import datetime
import functools
import hashlib
import inspect
Expand Down Expand Up @@ -371,6 +372,9 @@ def _to_bytes(self, obj: Any) -> bytes:
elif isinstance(obj, uuid.UUID):
return obj.bytes

elif isinstance(obj, datetime.datetime):
return obj.isoformat().encode()

elif isinstance(obj, (list, tuple)):
h = hashlib.new("md5")
for item in obj:
Expand Down
57 changes: 56 additions & 1 deletion lib/tests/streamlit/runtime/caching/hashing_test.py
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

"""st.memo/singleton hashing tests."""

import datetime
import functools
import hashlib
import os
Expand All @@ -29,8 +29,11 @@
from unittest.mock import MagicMock, Mock

import cffi
import dateutil.tz
import numpy as np
import pandas
import pandas as pd
import tzlocal
from parameterized import parameterized
from PIL import Image

Expand Down Expand Up @@ -89,6 +92,58 @@ def test_uuid(self):
self.assertNotEqual(id(uuid3), id(uuid3_copy))
self.assertNotEqual(get_hash(uuid3), get_hash(uuid4))

def test_datetime_naive(self):
    """Naive datetimes hash by value: equal values match, distinct values differ."""
    first = datetime.datetime(2007, 12, 23, 15, 45, 55)
    first_twin = datetime.datetime(2007, 12, 23, 15, 45, 55)
    other = datetime.datetime(2011, 12, 21, 15, 45, 55)

    first_hash = get_hash(first)

    # Equal field values must produce equal hashes ...
    self.assertEqual(first_hash, get_hash(first_twin))
    # ... even though the two values are separate objects.
    self.assertNotEqual(id(first), id(first_twin))
    # A different date must produce a different hash.
    self.assertNotEqual(first_hash, get_hash(other))

@parameterized.expand(
    [
        datetime.timezone.utc,
        tzlocal.get_localzone(),
        dateutil.tz.gettz("America/Los_Angeles"),
        dateutil.tz.gettz("Europe/Berlin"),
        dateutil.tz.UTC,
    ]
)
def test_datetime_aware(self, tz_info):
    """Aware datetimes hash by value and never collide with their naive twin.

    Runs once per tzinfo object supplied by the decorator above.
    """
    wall_clock = (2007, 12, 23, 15, 45, 55)

    aware = datetime.datetime(*wall_clock, tzinfo=tz_info)
    aware_twin = datetime.datetime(*wall_clock, tzinfo=tz_info)
    aware_other = datetime.datetime(2011, 12, 21, 15, 45, 55, tzinfo=tz_info)

    # Same wall-clock fields as `aware`, but with no tzinfo attached;
    # its hash is expected to differ from the aware one.
    naive = datetime.datetime(*wall_clock)

    aware_hash = get_hash(aware)

    self.assertEqual(aware_hash, get_hash(aware_twin))
    self.assertNotEqual(id(aware), id(aware_twin))
    self.assertNotEqual(aware_hash, get_hash(aware_other))
    self.assertNotEqual(aware_hash, get_hash(naive))

@parameterized.expand(
    [
        "US/Pacific",
        "America/Los_Angeles",
        "Europe/Berlin",
        "UTC",
        None,  # check for naive too
    ]
)
def test_pandas_timestamp(self, tz_info):
    """pandas Timestamps hash by value for each tz name, and for tz=None."""
    stamp = pandas.Timestamp("2017-01-01T12", tz=tz_info)
    stamp_twin = pandas.Timestamp("2017-01-01T12", tz=tz_info)
    stamp_other = pandas.Timestamp("2019-01-01T12", tz=tz_info)

    stamp_hash = get_hash(stamp)

    # Two separately constructed but equal timestamps share a hash.
    self.assertEqual(stamp_hash, get_hash(stamp_twin))
    self.assertNotEqual(id(stamp), id(stamp_twin))
    # A timestamp two years later hashes differently.
    self.assertNotEqual(stamp_hash, get_hash(stamp_other))

def test_mocks_do_not_result_in_infinite_recursion(self):
try:
get_hash(Mock())
Expand Down

0 comments on commit 010d7fd

Please sign in to comment.