Like Python functools.lru_cache but with a ttl (time to live).

In [43]:
# export
from collections.abc import MutableMapping
import json
import pprint
from operator import itemgetter
from functools import wraps, partial
from typing import Callable, Tuple, Dict, Any, Optional
import sqlite3
import time

In [126]:
# export
class SQLTtlCache(MutableMapping):
    """Dict-like cache backed by SQLite whose entries expire after ``ttl`` seconds.

    Values go through ``encoder`` before being stored as text and through
    ``decoder`` when read back (JSON by default).  Every key has an expiry
    timestamp in a companion ``Ttls`` table; an expired entry is removed the
    next time it is looked up.  ``len()`` and iteration report the rows that
    are currently stored, which can include expired entries that have not been
    looked up yet.
    """

    def __init__(
        self,
        dbname=":memory:",
        ttl: Optional[int] = None,  # seconds
        maxsize: Optional[int] = None,
        check_same_thread=False,
        fast=True,
        encoder: Callable = lambda x: json.dumps(x),
        decoder: Callable = lambda x: json.loads(x),
        **kwargs,
    ):
        """Open (or create) the database and make sure the schema exists.

        Args:
            dbname: Database path handed to ``sqlite3.connect``.
            ttl: Lifetime of an entry in seconds; a falsy value means 600.
            maxsize: Maximum number of rows; ``None`` means unbounded.  When
                set, the row with the smallest rowid is dropped before an
                insert that would exceed the limit.
            check_same_thread: Forwarded to ``sqlite3.connect``.
            fast: Apply the speed-oriented PRAGMAs below.
            encoder: Turns a value into the text that is stored.
            decoder: Turns stored text back into a value.
            **kwargs: Extra keyword arguments for ``sqlite3.connect``.
        """
        self.dbname = dbname
        self.conn = sqlite3.connect(
            self.dbname, check_same_thread=check_same_thread, **kwargs
        )
        self.encoder = encoder
        self.decoder = decoder
        self.maxsize = maxsize
        self.ttl = ttl if ttl else 600

        with self.conn as c:
            # The Cache -> Ttls cascade only works with foreign keys enabled.
            c.execute("PRAGMA foreign_keys = ON;")

            c.execute(
                "CREATE TABLE IF NOT EXISTS Cache (key text NOT NULL PRIMARY KEY, value text)"
            )

            c.execute(
                """CREATE TABLE IF NOT EXISTS Ttls ( key text NOT NULL PRIMARY KEY,
                                                     ttl int NOT NULL,
                                                     FOREIGN KEY (key) REFERENCES Cache (key)
                                                        ON DELETE CASCADE
                                                        ON UPDATE CASCADE )"""
            )

            if fast:
                c.execute("PRAGMA journal_mode = 'WAL';")
                c.execute("PRAGMA temp_store = 2;")
                c.execute("PRAGMA synchronous = 1;")
                c.execute(f"PRAGMA cache_size = {-1 * 64_000};")

            if self.maxsize is not None:
                # The limit is coerced to int because it is interpolated into
                # the trigger definition rather than bound as a parameter.
                limit = int(self.maxsize)
                c.execute(
                    f"""
CREATE TRIGGER IF NOT EXISTS maxsize_control
   BEFORE INSERT
   ON Cache
   WHEN (SELECT COUNT(*) FROM Cache) >= {limit}
BEGIN
    DELETE FROM Cache WHERE rowid = (SELECT min(rowid) FROM Cache);
END;"""
                )

    def __setitem__(self, key, value):
        """Store ``value`` under ``key`` and record its expiry timestamp."""
        with self.conn as c:
            c.execute(
                "INSERT OR REPLACE INTO  Cache VALUES (?, ?)",
                (key, self.encoder(value)),
            )

            # OR IGNORE leaves an existing Ttls row untouched.
            # NOTE(review): replacing the Cache row may cascade-delete the old
            # Ttls row first, in which case the expiry is refreshed - confirm.
            c.execute(
                "INSERT OR IGNORE INTO Ttls VALUES (?, (strftime('%s','now') + ?))",
                (key, self.ttl),
            )

    def __getitem__(self, key):
        """Return the decoded value for ``key``.

        Raises:
            KeyError: If the key is absent or its entry has expired.  An
                expired entry is deleted as part of the lookup.
        """
        with self.conn as c:
            expiry = c.execute(
                "SELECT ttl FROM Ttls WHERE key = ?", (key,)
            ).fetchone()

            if expiry is None or expiry[0] < time.time():
                c.execute("DELETE FROM Cache WHERE key = ?", (key,))
                row = None
            else:
                row = c.execute(
                    "SELECT value FROM Cache WHERE key = ?", (key,)
                ).fetchone()

        # Raise only after the ``with`` block has exited: an exception inside
        # it makes the connection roll back, which would undo the DELETE.
        if row is None:
            raise KeyError(key)
        return self.decoder(row[0])

    def __delitem__(self, key):
        """Remove ``key``; raise ``KeyError`` if it is absent or expired."""
        if key not in self:
            raise KeyError(key)
        with self.conn as c:
            c.execute("DELETE FROM Cache WHERE key=?", (key,))

    def __len__(self):
        """Return the number of rows currently stored in ``Cache``."""
        return self.conn.execute("SELECT COUNT(*) FROM Cache").fetchone()[0]

    def __iter__(self):
        """Iterate over the keys currently stored in ``Cache``."""
        rows = self.conn.execute("SELECT key FROM Cache").fetchall()
        return map(itemgetter(0), rows)

    def __repr__(self):
        return f"{type(self).__name__}(dbname={self.dbname!r}, items={pprint.pformat(list(self.items()))})"

    def vacuum(self):
        """Run ``VACUUM`` on the database."""
        # VACUUM fails inside an open transaction, so finish any pending one.
        self.conn.commit()
        self.conn.execute("VACUUM;")

    def clear(self):
        """Delete every entry and commit the deletion."""
        with self.conn as c:
            c.execute("DELETE FROM Cache")

    def close(self):
        """Close the underlying SQLite connection."""
        self.conn.close()

In [127]:
c = SQLTtlCache(ttl=20)

In [128]:
c["asd"] = 123

In [129]:
c

SQLTtlCache(dbname=':memory:', items=[('asd', 123)])

In [130]:
a = (1, "asdas", 3)
kw = {"bb": 123, "aa": 34, "xs": "bm"}

In [131]:
sorted(kw.items())

[('aa', 34), ('bb', 123), ('xs', 'bm')]

In [132]:
str((a,kw))

"((1, 'asdas', 3), {'bb': 123, 'aa': 34, 'xs': 'bm'})"

Adapted from the Python source code

In [152]:
# export
def make_key(
    args,
    kwargs,
    kwd_mark=("::",),
    fasttypes={int, str},
):
    """Build a flat cache key from positional and keyword arguments.

    The result is a single tuple: the positional arguments, then (only when
    keyword arguments are present) ``kwd_mark`` followed by each keyword name
    and its value in the order the caller supplied them.  Keyword order is
    significant, so ``f(x=1, y=2)`` and ``f(y=2, x=1)`` produce different keys.

    Args:
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call; may be ``None`` or empty.
        kwd_mark: Tuple placed between positional and keyword parts.
        fasttypes: Not used by this implementation; kept so the signature
            stays unchanged.

    Returns:
        A new tuple; ``args`` itself is never modified.
    """
    # Copy into a tuple first: with a list, ``+=`` would extend the caller's
    # object in place instead of building a new key.
    key = tuple(args)
    if kwargs:
        key += kwd_mark
        for name_and_value in kwargs.items():
            key += name_and_value
    return key

Without kwargs

In [134]:
make_key((1,), None)

(1,)

With kwargs

In [135]:
make_key((1,2,3), {"a": 12, "b": "mm"})

(1, 2, 3, '::', 'a', 12, 'b', 'mm')

I can't use `cache_key = str((func.__name__, args + tuple(kwargs.values())))` because that key ignores the keyword names: after one call with the correct `kwargs`, a later call that passes the same values under an incorrect `kwarg` name would find the cached entry and return a value instead of raising.

In [136]:
# export
def lru_cache(
    func: Optional[Callable] = None, ttl: Optional[int] = None, maxsize: Optional[int] = None
) -> Callable:
    """Memoize ``func`` in an in-memory ``SQLTtlCache``.

    Works both bare (``@lru_cache``) and with arguments
    (``@lru_cache(ttl=2, maxsize=20)``).

    Args:
        func: The function to wrap; ``None`` when the decorator is called
            with keyword arguments only.
        ttl: Entry lifetime in seconds, forwarded to ``SQLTtlCache``.
        maxsize: Maximum number of cached entries, forwarded to
            ``SQLTtlCache``.

    Returns:
        The wrapped function, carrying three extra attributes:
        ``cache_info()`` returning ``(entries, hits)``, ``cache_clear()``,
        and ``cache`` (the underlying ``SQLTtlCache``).
    """
    if func is None:
        # Called as @lru_cache(...): return a decorator that waits for func.
        # This check comes first so no database is opened just to be dropped.
        return partial(lru_cache, maxsize=maxsize, ttl=ttl)

    cache = SQLTtlCache(":memory:", maxsize=maxsize, ttl=ttl)
    hits = 0
    # Private marker so that a cached ``None`` result still counts as a hit.
    missing = object()

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal hits

        # The key is a string because it is stored in a text column.
        cache_key = str((func.__name__,) + make_key(args, kwargs))

        cached = cache.get(cache_key, missing)
        if cached is not missing:
            hits += 1
            # Printed marker that the demo cells rely on to show a hit.
            print("Hit!")
            return cached

        result = func(*args, **kwargs)
        cache[cache_key] = result
        return result

    wrapper.cache_info = lambda: (len(cache), hits)
    wrapper.cache_clear = lambda: cache.clear()
    wrapper.cache = cache

    return wrapper

Check normal cache behaviour

In [137]:
@lru_cache(maxsize=20)
def add(a, b, c=None):
    """Return ``a + b + c``, counting an omitted ``c`` as 0."""
    return a + b + (0 if c is None else c)

In [138]:
add(1,2,3)

for i in range(19):
    add(1, i, c=i)

assert add.cache_info()[0] == 20

add(200, 150, 100)

# maxsize works

assert add.cache_info()[0] == 20

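**TODO** hits not working (note: none of the calls above repeats an earlier argument combination, so `0` hits is the expected value in the next cell)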

In [139]:
add.cache_info()

(20, 0)

In [140]:
# cache clearing works

add.cache_clear()

assert add.cache_info()[0] == 0

In [141]:
# incorrect kwargs are not cached: the wrapped function raises TypeError
# before any result is stored

try:
    add(1, b=24, x=4)
except TypeError:
    print("pass")

pass


In [142]:
@lru_cache
def add(a, b, c=None):
    """Sum the three arguments; ``c`` defaults to 0 when left out."""
    third = 0 if c is None else c
    return a + b + third

In [143]:
# one positional-only call plus 200 distinct keyword calls -> 201 entries
add(1,2,3)

for i in range(200):
    add(1, i, c=i)

assert add.cache_info()[0] == 201

add(200, 150, 100)

# no maxsize was given for this cache, so nothing is evicted and the new
# call simply adds one more entry

assert add.cache_info()[0] == 202

In [144]:
# cache clearing works

add.cache_clear()

assert add.cache_info()[0] == 0

Check ttl

In [145]:
@lru_cache(ttl=2, maxsize=20)
def add(a, b, c=None):
    """Add ``a``, ``b`` and ``c``; a missing ``c`` contributes 0."""
    if c is not None:
        return a + b + c
    return a + b + 0

In [146]:
add(1,2,3)

6

In [147]:
add(1,2,3)

Hit!


6

In [148]:
time.sleep(2)

In [149]:
add(1,2,3)

6

In [150]:
for i in range(19):
    add(1, i, c=i)

assert add.cache_info()[0] == 20

add(200, 150, 100)

# maxsize works

assert add.cache_info()[0] == 20

In [151]:
# cache clearing works

add.cache_clear()

assert add.cache_info()[0] == 0

**Benchmarks**

In [19]:
from functools import lru_cache as pycache

In [20]:
@lru_cache(maxsize=None)
def add1(a, b, c=None):
    """Benchmark target for the SQLite-backed cache: ``a + b + c`` (``c`` defaults to 0)."""
    return a + b + (0 if c is None else c)


@pycache(maxsize=None)
def add2(a, b, c=None):
    """Benchmark target for ``functools.lru_cache``: ``a + b + c`` (``c`` defaults to 0)."""
    return a + b + (0 if c is None else c)

In [21]:
%%timeit

for i in range(200):
    add1(1, i, c=i)

2.6 ms ± 92.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [22]:
%%timeit

for i in range(200):
    add2(1, i, c=i)

55 µs ± 5.45 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [26]:
@lru_cache(maxsize=1000)
def add1(a, b, c=None):
    """Bounded-cache benchmark target: sum of the arguments, with ``c`` read as 0 when omitted."""
    extra = 0 if c is None else c
    return a + b + extra


@pycache(maxsize=1000)
def add2(a, b, c=None):
    """Bounded ``functools.lru_cache`` benchmark target: sum of the arguments, ``c`` read as 0 when omitted."""
    extra = 0 if c is None else c
    return a + b + extra

In [27]:
%%timeit

for i in range(2000):
    add1(1, i, c=i)

58.6 ms ± 479 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [28]:
%%timeit

for i in range(2000):
    add2(1, i, c=i)

1.17 ms ± 86.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
