diff --git a/tensorly/tenalg/einsum_tenalg/caching.py b/tensorly/tenalg/einsum_tenalg/caching.py new file mode 100644 index 000000000..6fb48bf75 --- /dev/null +++ b/tensorly/tenalg/einsum_tenalg/caching.py @@ -0,0 +1,20 @@ +EINSUM_PATH_CACHE = dict() + + +def einsum_path_cached(fun): + def wrapped(key, *args, **kwargs): + name = fun.__name__ + try: + cache = EINSUM_PATH_CACHE[name] + except KeyError: + EINSUM_PATH_CACHE[name] = dict() + cache = EINSUM_PATH_CACHE[name] + try: + equation = cache[key] + except KeyError: + equation = fun(*args, **kwargs) + cache[key] = equation + + return equation + + return wrapped