Reimplementation of Python functools.lru_cache

In [1]:
# export
from collections.abc import MutableMapping
import json
import pprint
from functools import wraps, partial
from typing import Callable, Tuple, Dict, Any
import sqlite3
from operator import itemgetter

In [2]:
# import orjson as json

In [29]:
# export
class SQLCache(MutableMapping):
    def __init__(
        self,
        dbname,
        maxsize: int = None,
        check_same_thread=False,
        fast=True,
        encoder: Callable = lambda x: json.dumps(x),
        decoder: Callable = lambda x: json.loads(x),
        **kwargs,
    ):
        self.dbname = dbname
        self.conn = sqlite3.connect(
            self.dbname, check_same_thread=check_same_thread, **kwargs
        )
        self.encoder = encoder
        self.decoder = decoder
        self.maxsize = maxsize

        with self.conn as c:
            c.execute(
                "CREATE TABLE IF NOT EXISTS Cache (key text NOT NULL PRIMARY KEY, value text)"
            )

            if fast:
                c.execute("PRAGMA journal_mode = 'WAL';")
                c.execute("PRAGMA temp_store = 2;")
                c.execute("PRAGMA synchronous = 1;")
                c.execute(f"PRAGMA cache_size = {-1 * 64_000};")

            if maxsize is not None:
                c.execute(
                    f"""
CREATE TRIGGER IF NOT EXISTS maxsize_control 
   BEFORE INSERT
   ON Cache
   WHEN (SELECT COUNT(*) FROM Cache) >= {self.maxsize}
BEGIN
    DELETE FROM Cache WHERE rowid = (SELECT min(rowid) FROM Cache);
END;"""
                )

    def __setitem__(self, key, value):
        with self.conn as c:
            c.execute(
                "INSERT OR REPLACE INTO  Cache VALUES (?, ?)",
                (key, self.encoder(value)),
            )

    def __getitem__(self, key):
        c = self.conn.execute("SELECT value FROM Cache WHERE Key=?", (key,))
        row = c.fetchone()
        if row is None:
            raise KeyError(key)
        return self.decoder(row[0])

    def __delitem__(self, key):
        if key not in self:
            raise KeyError(key)
        with self.conn as c:
            c.execute("DELETE FROM Cache WHERE key=?", (key,))

    def __len__(self):
        return next(self.conn.execute("SELECT COUNT(*) FROM Cache"))[0]

    def __iter__(self):
        c = self.conn.execute("SELECT key FROM Cache")
        return map(itemgetter(0), c.fetchall())

    def __repr__(self):
        return f"{type(self).__name__}(dbname={self.dbname!r}, items={pprint.pformat(list(self.items()))})"

    def vacuum(self):
        self.conn.execute("VACUUM;")

    def clear(self):
        self.conn.execute("DELETE FROM Cache")

    def close(self):
        self.conn.close()

In [30]:
a = (1, "asdas", 3)
kw = {"bb": 123, "aa": 34, "xs": "bm"}

In [31]:
sorted(kw.items())

[('aa', 34), ('bb', 123), ('xs', 'bm')]

In [32]:
str((a,kw))

"((1, 'asdas', 3), {'bb': 123, 'aa': 34, 'xs': 'bm'})"

Adapted from the Python source code

In [33]:
def make_key(
    args,
    kwargs,
    kwd_mark=("::",),
    fasttypes={int, str},
):
    """Make a cache key from optionally typed positional and keyword arguments

    The key is constructed in a way that is flat as possible rather than
    as a nested structure that would take more memory.

    If there is only a single argument and its data type is known to cache
    its hash value, then that argument is returned without a wrapper.  This
    saves space and improves lookup speed.

    """
    # All of code below relies on kwds preserving the order input by the user.
    # Formerly, we sorted() the kwds before looping.  The new way is *much*
    # faster; however, it means that f(x=1, y=2) will now be treated as a
    # distinct call from f(y=2, x=1) which will be cached separately.
    key = args
    if kwargs:
        key += kwd_mark
        for item in kwargs.items():
            key += item
    return key

Without kwargs

In [34]:
make_key((1,), None)

(1,)

With kwargs

In [35]:
make_key((1,2,3), {"a": 12, "b": "mm"})

(1, 2, 3, '::', 'a', 12, 'b', 'mm')

I can't use `cache_key = str((func.__name__, args + tuple(kwargs.values())))` because if we call the function once with the correct `kwargs` then we could call it with an incorrect `kwarg` name and it would still return a value.

In [36]:
def lru_cache(func: Callable = None, maxsize: int = None) -> Callable:

    # print(func)
    # print(maxsize)

    cache = SQLCache(":memory:", maxsize=maxsize)

    if func is None:
        return partial(lru_cache, maxsize=maxsize)

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:

        # print(f"Len: {len(cache)}")

        cache_key = str((func.__name__,) + make_key(args, kwargs))

        # print(f"Cache key: {cache_key}")

        if cache_key in cache:
            # print("Hit!")
            return cache[cache_key]

        result = func(*args, **kwargs)

        cache[cache_key] = result

        return result

    wrapper.cache_size = lambda: len(cache)
    wrapper.cache_clear = lambda: cache.clear()
    wrapper.cache = cache

    return wrapper

In [37]:
@lru_cache(maxsize=20)
def add(a, b, c=None):
    
    if c is None:
        c = 0
    
    return a + b + c

In [38]:
add(1,2,3)

for i in range(19):
    add(1, i, c=i)

assert add.cache_size() == 20

add(200, 150, 100)

# maxsize works

assert add.cache_size() == 20

In [39]:
# cache clearing works

add.cache_clear()

assert add.cache_size() == 0

In [40]:
# incorrect kwars are not cached

try:
    add(1, b=24, x=4)
except TypeError:
    print("pass")

pass


In [41]:
@lru_cache
def add(a, b, c=None):
    
    if c is None:
        c = 0
    
    return a + b + c

In [42]:
add(1,2,3)

for i in range(200):
    add(1, i, c=i)

assert add.cache_size() == 201

add(200, 150, 100)

# maxsize works

assert add.cache_size() == 202

In [43]:
add(1,2,3)

6

In [24]:
# cache clearing works

add.cache_clear()

assert add.cache_size() == 0

**Benchmarks**

In [19]:
from functools import lru_cache as pycache

In [20]:
@lru_cache(maxsize=None)
def add1(a, b, c=None):

    if c is None:
        c = 0

    return a + b + c


@pycache(maxsize=None)
def add2(a, b, c=None):

    if c is None:
        c = 0

    return a + b + c

In [21]:
%%timeit

for i in range(200):
    add1(1, i, c=i)

2.6 ms ± 92.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [22]:
%%timeit

for i in range(200):
    add2(1, i, c=i)

55 µs ± 5.45 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [26]:
@lru_cache(maxsize=1000)
def add1(a, b, c=None):

    if c is None:
        c = 0

    return a + b + c


@pycache(maxsize=1000)
def add2(a, b, c=None):

    if c is None:
        c = 0

    return a + b + c

In [27]:
%%timeit

for i in range(2000):
    add1(1, i, c=i)

58.6 ms ± 479 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [28]:
%%timeit

for i in range(2000):
    add2(1, i, c=i)

1.17 ms ± 86.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


From the Python source code

In [306]:
class _HashedSeq(list):
    """This class guarantees that hash() will be called no more than once
    per element.  This is important because the lru_cache() will hash
    the key multiple times on a cache miss.

    """

    __slots__ = "hashvalue"

    def __init__(self, tup, hash=hash):
        self[:] = tup
        self.hashvalue = hash(tup)

    def __hash__(self):
        return self.hashvalue


def _make_key(
    args,
    kwds,
    typed,
    kwd_mark=(object(),),
    fasttypes={int, str},
    tuple=tuple,
    type=type,
    len=len,
):
    """Make a cache key from optionally typed positional and keyword arguments

    The key is constructed in a way that is flat as possible rather than
    as a nested structure that would take more memory.

    If there is only a single argument and its data type is known to cache
    its hash value, then that argument is returned without a wrapper.  This
    saves space and improves lookup speed.

    """
    # All of code below relies on kwds preserving the order input by the user.
    # Formerly, we sorted() the kwds before looping.  The new way is *much*
    # faster; however, it means that f(x=1, y=2) will now be treated as a
    # distinct call from f(y=2, x=1) which will be cached separately.
    key = args
    if kwds:
        key += kwd_mark
        for item in kwds.items():
            key += item
    if typed:
        key += tuple(type(v) for v in args)
        if kwds:
            key += tuple(type(v) for v in kwds.values())
    elif len(key) == 1 and type(key[0]) in fasttypes:
        return key[0]
    return _HashedSeq(key)