Skip to content

Commit

Permalink
Better dataframe hashing (streamlit#7331)
Browse files Browse the repository at this point in the history
Add column names to the hash for DataFrames.
Remove the memoization heuristic for DataFrames and NumPy arrays, because they can be modified in place.
  • Loading branch information
kajarenc authored and zyxue committed Apr 16, 2024
1 parent 2729165 commit 8f84357
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 19 deletions.
42 changes: 31 additions & 11 deletions lib/streamlit/runtime/caching/hashing.py
Expand Up @@ -252,13 +252,7 @@ def is_simple(obj):
if all(map(is_simple, obj)):
return ("__l", tuple(obj))

if (
type_util.is_type(obj, "pandas.core.frame.DataFrame")
or type_util.is_type(obj, "numpy.ndarray")
or inspect.isbuiltin(obj)
or inspect.isroutine(obj)
or inspect.iscode(obj)
):
if inspect.isbuiltin(obj) or inspect.isroutine(obj) or inspect.iscode(obj):
return id(obj)

return NoResult
Expand Down Expand Up @@ -402,15 +396,40 @@ def _to_bytes(self, obj: Any) -> bytes:
elif isinstance(obj, Enum):
return str(obj).encode()

elif type_util.is_type(obj, "pandas.core.frame.DataFrame") or type_util.is_type(
obj, "pandas.core.series.Series"
):
elif type_util.is_type(obj, "pandas.core.series.Series"):
import pandas as pd

h = hashlib.new("md5")
self.update(h, obj.size)
self.update(h, obj.dtype.name)

if len(obj) >= _PANDAS_ROWS_LARGE:
obj = obj.sample(n=_PANDAS_SAMPLE_SIZE, random_state=0)

try:
self.update(h, pd.util.hash_pandas_object(obj).values.tobytes())
return h.digest()
except TypeError:
# Fall back to pickle if pandas cannot hash the object, for example
# when it contains unhashable objects.
return b"%s" % pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)

elif type_util.is_type(obj, "pandas.core.frame.DataFrame"):
import pandas as pd

h = hashlib.new("md5")
self.update(h, obj.shape)

if len(obj) >= _PANDAS_ROWS_LARGE:
obj = obj.sample(n=_PANDAS_SAMPLE_SIZE, random_state=0)
try:
return b"%s" % pd.util.hash_pandas_object(obj).sum()
column_hash_bytes = self.to_bytes(
pd.util.hash_pandas_object(obj.dtypes)
)
self.update(h, column_hash_bytes)
values_hash_bytes = self.to_bytes(pd.util.hash_pandas_object(obj))
self.update(h, values_hash_bytes)
return h.digest()
except TypeError:
# Fall back to pickle if pandas cannot hash the object, for example
# when it contains unhashable objects.
Expand All @@ -419,6 +438,7 @@ def _to_bytes(self, obj: Any) -> bytes:
elif type_util.is_type(obj, "numpy.ndarray"):
h = hashlib.new("md5")
self.update(h, obj.shape)
self.update(h, str(obj.dtype))

if obj.size >= _NP_SIZE_LARGE:
import numpy as np
Expand Down
105 changes: 97 additions & 8 deletions lib/tests/streamlit/runtime/caching/hashing_test.py
Expand Up @@ -265,18 +265,91 @@ def test_regex(self):
self.assertEqual(get_hash(p1), get_hash(p2))
self.assertNotEqual(get_hash(p1), get_hash(p3))

def test_pandas_dataframe(self):
df1 = pd.DataFrame({"foo": [12]})
df2 = pd.DataFrame({"foo": [42]})
df3 = pd.DataFrame({"foo": [12]})
def test_pandas_large_dataframe(self):

df1 = pd.DataFrame(np.zeros((_PANDAS_ROWS_LARGE, 4)), columns=list("ABCD"))
df2 = pd.DataFrame(np.ones((_PANDAS_ROWS_LARGE, 4)), columns=list("ABCD"))
df3 = pd.DataFrame(np.zeros((_PANDAS_ROWS_LARGE, 4)), columns=list("ABCD"))

self.assertEqual(get_hash(df1), get_hash(df3))
self.assertNotEqual(get_hash(df1), get_hash(df2))

df4 = pd.DataFrame(np.zeros((_PANDAS_ROWS_LARGE, 4)), columns=list("ABCD"))
df5 = pd.DataFrame(np.zeros((_PANDAS_ROWS_LARGE, 4)), columns=list("ABCD"))

self.assertEqual(get_hash(df4), get_hash(df5))
@parameterized.expand(
    [
        # Identical single-cell frames.
        (pd.DataFrame({"foo": [12]}), pd.DataFrame({"foo": [12]}), True),
        # Single-cell frames whose only value differs.
        (pd.DataFrame({"foo": [12]}), pd.DataFrame({"foo": [42]}), False),
        # Identical two-column frames.
        (
            pd.DataFrame(data={"A": [1, 2, 3], "B": [2, 3, 4]}),
            pd.DataFrame(data={"A": [1, 2, 3], "B": [2, 3, 4]}),
            True,
        ),
        # Second frame carries an additional column "C".
        (
            pd.DataFrame(data={"A": [1, 2, 3], "B": [2, 3, 4]}),
            pd.DataFrame(data={"A": [1, 2, 3], "B": [2, 3, 4], "C": [1, 2, 3]}),
            False,
        ),
        # One cell value differs (last entry of "B").
        (
            pd.DataFrame(data={"A": [1, 2, 3], "B": [2, 3, 4]}),
            pd.DataFrame(data={"A": [1, 2, 3], "B": [2, 3, 5]}),
            False,
        ),
        # Same data in the same positions, but column labels swapped.
        (
            pd.DataFrame(data={"A": [1, 2, 3], "B": [2, 3, 4]}),
            pd.DataFrame(data={"B": [1, 2, 3], "A": [2, 3, 4]}),
            False,
        ),
        # Same values, different index labels.
        (
            pd.DataFrame(data={"A": [1, 2, 3], "B": [2, 3, 4]}, index=[1, 2, 3]),
            pd.DataFrame(data={"A": [1, 2, 3], "B": [2, 3, 4]}, index=[1, 2, 4]),
            False,
        ),
        # Second frame lacks column "B".
        (
            pd.DataFrame(data={"A": [1, 2, 3], "B": [2, 3, 4]}),
            pd.DataFrame(data={"A": [1, 2, 3]}),
            False,
        ),
        # Same rows, sorted ascending by "A" versus descending by "B".
        (
            pd.DataFrame(data={"A": [1, 2, 3], "B": [2, 3, 4]}).sort_values(
                by=["A"]
            ),
            pd.DataFrame(data={"A": [1, 2, 3], "B": [2, 3, 4]}).sort_values(
                by=["B"], ascending=False
            ),
            False,
        ),
        # Same values under a different column header ("C" versus "B").
        (
            pd.DataFrame(data={"A": [1, 2, 3], "C": [2, 3, 4]}),
            pd.DataFrame(data={"A": [1, 2, 3], "B": [2, 3, 4]}),
            False,
        ),
        # Same columns declared in the opposite order.
        (
            pd.DataFrame(data={"A": [1, 2, 3], "C": [2, 3, 4]}),
            pd.DataFrame(data={"C": [2, 3, 4], "A": [1, 2, 3]}),
            False,
        ),
        # Same values stored as UInt64 versus Int64 in column "C".
        (
            pd.DataFrame(
                data={"A": [1, 2, 3], "C": pd.array([1, 2, 3], dtype="UInt64")}
            ),
            pd.DataFrame(
                data={"A": [1, 2, 3], "C": pd.array([1, 2, 3], dtype="Int64")}
            ),
            False,
        ),
    ]
)
def test_pandas_dataframe(self, df1, df2, expected):
    """Check whether two DataFrames hash equally, per the expectation.

    Each parameter set supplies two frames and a boolean stating whether
    ``get_hash`` is expected to return the same value for both.
    """
    hashes_match = get_hash(df1) == get_hash(df2)
    self.assertEqual(hashes_match, expected)

def test_pandas_series(self):
series1 = pd.Series([1, 2])
Expand All @@ -291,6 +364,12 @@ def test_pandas_series(self):

self.assertEqual(get_hash(series4), get_hash(series5))

def test_pandas_series_similar_dtypes(self):
    """Series with equal values but UInt64 vs Int64 dtype must hash differently."""
    unsigned_series = pd.Series([1, 2], dtype="UInt64")
    signed_series = pd.Series([1, 2], dtype="Int64")

    unsigned_hash = get_hash(unsigned_series)
    signed_hash = get_hash(signed_series)

    self.assertNotEqual(unsigned_hash, signed_hash)

def test_numpy(self):
np1 = np.zeros(10)
np2 = np.zeros(11)
Expand All @@ -304,6 +383,16 @@ def test_numpy(self):

self.assertEqual(get_hash(np4), get_hash(np5))

def test_numpy_similar_dtypes(self):
    """Arrays of ones that differ only in dtype must hash differently.

    Covers both plain scalar dtypes (u8 vs i8) and structured dtypes whose
    fields "a" and "b" have their u8/i8 types swapped.
    """
    # Plain dtypes: unsigned versus signed 8-byte integers.
    unsigned_ones = np.ones(10, dtype="u8")
    signed_ones = np.ones(10, dtype="i8")
    self.assertNotEqual(get_hash(unsigned_ones), get_hash(signed_ones))

    # Structured dtypes: same field names, field types swapped.
    struct_u_then_i = np.ones(10, dtype=[("a", "u8"), ("b", "i8")])
    struct_i_then_u = np.ones(10, dtype=[("a", "i8"), ("b", "u8")])
    self.assertNotEqual(get_hash(struct_u_then_i), get_hash(struct_i_then_u))

def test_PIL_image(self):
im1 = Image.new("RGB", (50, 50), (220, 20, 60))
im2 = Image.new("RGB", (50, 50), (30, 144, 255))
Expand Down

0 comments on commit 8f84357

Please sign in to comment.