# Fast exponentiation

Write a function `exp(a, n)` that raises an integer `a` to a non-negative integer power `n` using fewer than `O(n)` multiplications.


In [None]:
def exp(a: int, n: int) -> int:
    """
    Return a ** n using O(log n) multiplications (repeated squaring).

    Brute force: a^n = a * a * ... * a  [n factors, n - 1 multiplications]

    Repeated squaring reuses intermediate results instead.

    For example, with a = 5, n = 4:
      5^4 = 5 * 5 * 5 * 5 = 625           3 mults
      5^4 = (5^2)^2 = (5*5)^2             2 mults

    With a = 3, n = 7 (the decomposition this function computes):
      3^7 = 3 * 3 * 3 * 3 * 3 * 3 * 3     6 mults
          = 3 * (3^3)^2
          = 3 * (3 * (3*3))^2             4 mults

    Plan:
      1. Base cases: n = 0 returns 1, n = 1 returns a
      2. If n is odd, return a * exp(a, n - 1)   (n - 1 is then even)
      3. If n is even, return r * r where r = exp(a, n >> 1)

    Every odd step is immediately followed by an even (halving) step, so
    the recursion makes at most about 2 * log2(n) multiplications.

    Args:
        a: The base.
        n: The exponent; must be non-negative.

    Returns:
        a raised to the power n.

    Raises:
        ValueError: If n is negative (the recursion would never reach a
            base case).
    """
    if n < 0:
        raise ValueError(f"exponent must be non-negative, got {n}")
    if n == 0:
        return 1
    if n == 1:
        return a
    if n % 2 == 1:
        # Peel off one factor so the remaining exponent is even.
        return a * exp(a, n - 1)
    # Even exponent: a^n = (a^(n/2))^2, so compute the half power once.
    half = exp(a, n >> 1)
    return half * half

In [None]:
# (base, exponent, expected) triples: both base cases (n == 0, n == 1),
# odd and even exponents, a negative base, and a large exponent
# cross-checked against the built-in ** operator.
cases = (
    (5, 0, 1),
    (5, 1, 5),
    (5, 4, 625),
    (3, 7, 2187),
    (-2, 3, -8),
    (7, 100, 7**100),
)
for a, n, want in cases:
    got = exp(a, n)
    # Report the failing inputs alongside the wrong result.
    assert got == want, (a, n, got)

In [None]:
def exp_iter(a: int, n: int) -> int:
    """
    Return a ** n iteratively by binary (square-and-multiply) exponentiation.

    The bits of n are scanned from least to most significant. ``a`` holds
    the current power a^(2^k); whenever bit k of n is set, that power is
    multiplied into the accumulator. The loop runs once per bit of n, so it
    uses O(log n) multiplications.

    Args:
        a: The base.
        n: The exponent; must be non-negative.

    Returns:
        a raised to the power n (1 when n == 0).

    Raises:
        ValueError: If n is negative.
    """
    if n < 0:
        # Without this check the loop is skipped and 1 is returned silently.
        raise ValueError(f"exponent must be non-negative, got {n}")
    acc = 1
    while n:
        if n & 1:
            acc *= a
        n >>= 1
        if n:
            # Skip the squaring after the top bit: its result is never used
            # and squaring a large int is the most expensive step.
            a *= a
    return acc

In [None]:
# (base, exponent, expected) triples for exp_iter: the small cases from
# before, plus a negative base and a large exponent cross-checked against
# the built-in ** operator.
for a, n, want in (
    (5, 0, 1),
    (5, 1, 5),
    (5, 2, 25),
    (5, 3, 125),
    (5, 4, 625),
    (2, 5, 32),
    (4, 4, 256),
    (3, 7, 2187),
    (-2, 3, -8),
    (7, 100, 7**100),
):
    got = exp_iter(a, n)
    # Report the failing inputs alongside the wrong result.
    assert got == want, (a, n, got)

In [None]:
def exp_recur(a: int, n: int) -> int:
    """
    Return a ** n by recursive repeated squaring.

    Odd exponents peel off one factor (a * a^(n-1)). Even exponents square
    the half power ((a^(n/2))^2). Every odd step is followed by an even
    one, so this uses O(log n) multiplications and O(log n) recursion depth.

    Args:
        a: The base.
        n: The exponent; must be non-negative.

    Returns:
        a raised to the power n (1 when n == 0).

    Raises:
        ValueError: If n is negative (the recursion would never reach the
            n == 0 base case).
    """
    if n < 0:
        raise ValueError(f"exponent must be non-negative, got {n}")
    if n == 0:
        return 1
    if n % 2 == 1:
        return a * exp_recur(a, n - 1)
    half = exp_recur(a, n >> 1)
    return half * half

In [None]:
# (base, exponent, expected) triples for exp_recur: the small cases from
# before, plus a negative base and a large exponent cross-checked against
# the built-in ** operator.
for a, n, want in (
    (5, 0, 1),
    (5, 1, 5),
    (5, 2, 25),
    (5, 3, 125),
    (5, 4, 625),
    (2, 5, 32),
    (4, 4, 256),
    (3, 7, 2187),
    (-2, 3, -8),
    (7, 100, 7**100),
):
    got = exp_recur(a, n)
    # Report the failing inputs alongside the wrong result.
    assert got == want, (a, n, got)

In [None]:
from functools import cache


@cache
def exp_cache(a: int, n: int) -> int:
    """
    Return a ** n by splitting n into two near-equal halves, memoizing results.

    Uses a^n = a^floor(n/2) * a^ceil(n/2). Without memoization this split
    would still perform about n - 1 multiplications. At every level of the
    recursion, though, the exponents are only floor(n/2^k) or ceil(n/2^k),
    so with @cache only O(log n) distinct (a, exponent) pairs are computed.

    The cache (functools.cache) is unbounded and persists across calls; use
    exp_cache.cache_clear() to empty it.

    Args:
        a: The base.
        n: The exponent; must be non-negative.

    Returns:
        a raised to the power n (1 when n == 0).

    Raises:
        ValueError: If n is negative (n = -1 would split into -1 and 0 and
            recurse forever). Exceptions are not cached.
    """
    if n < 0:
        raise ValueError(f"exponent must be non-negative, got {n}")
    if n == 0:
        return 1
    if n == 1:
        return a
    left = n >> 1  # floor(n / 2)
    right = n - left  # ceil(n / 2); equals left when n is even
    return exp_cache(a, left) * exp_cache(a, right)


def exp_reset_cache(a: int, n: int) -> int:
    """
    Compute a ** n with exp_cache, starting from an empty cache.

    Clearing first means the call does not benefit from any results that
    earlier exp_cache calls left memoized.
    """
    exp_cache.cache_clear()  # discard everything memoized so far
    result = exp_cache(a, n)
    return result

In [None]:
import timeit
import matplotlib.pyplot as plt

n_values = range(1, 201)
BASE = 3  # same base for both functions, so they multiply identical numbers
RUNS = 1000  # calls per timing sample

# Measure the running time of exp_recur and exp_iter for every exponent in
# n_values. Each lambda is timed right away, so it sees the current n.
exp_recur_times = []
exp_iter_times = []

for n in n_values:
    exp_recur_times.append(
        timeit.timeit(lambda: exp_recur(BASE, n), number=RUNS)
    )
    exp_iter_times.append(
        timeit.timeit(lambda: exp_iter(BASE, n), number=RUNS)
    )


# Plot the results
plt.plot(n_values, exp_recur_times, label="exp_recur")
plt.plot(n_values, exp_iter_times, label="exp_iter")
plt.xlabel("n")
plt.ylabel(f"Time for {RUNS} calls (seconds)")
plt.title("Running Time of exp_recur vs exp_iter")
plt.legend()
plt.show()