In [None]:
from IPython.display import HTML
with open('../style.css') as file:
    css = file.read()
HTML(css)

# Integer Square Root

The function `isqrt(n)` takes a natural number $n$ and returns the largest natural number $r$ such that
$r^2 \leq n$, i.e. we have
$$ \texttt{isqrt}(n) := \max\bigl(\{ r \in \mathbb{N} \mid r^2 \leq n \}\bigr). $$
Our goal is to compute `isqrt(n)` recursively via a *divide-and-conquer* algorithm as follows:
1. $\texttt{isqrt}(0) = 0$.
2. $\bigl(2 \cdot \texttt{isqrt}(n \,\texttt{//}\, 4) + 1\bigr)^2 \leq n \rightarrow \texttt{isqrt}(n) = 2 \cdot \texttt{isqrt}(n \,\texttt{//}\, 4) + 1$.
3. $\bigl(2 \cdot \texttt{isqrt}(n \,\texttt{//}\, 4) + 1\bigr)^2 > n \rightarrow \texttt{isqrt}(n) = 2 \cdot \texttt{isqrt}(n \,\texttt{//}\, 4)$.
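
The following short argument sketches why this recursion is correct.  The abbreviations $q := n \,\texttt{//}\, 4$ and $s := \texttt{isqrt}(q)$ are introduced here only for brevity.  By the definition of $s$ we have $s^2 \leq q < (s+1)^2$, and the definition of integer division yields $4 \cdot q \leq n \leq 4 \cdot q + 3$.  Therefore
$$ (2 \cdot s)^2 = 4 \cdot s^2 \leq 4 \cdot q \leq n \quad\mbox{and}\quad n \leq 4 \cdot q + 3 \leq 4 \cdot \bigl((s+1)^2 - 1\bigr) + 3 < (2 \cdot s + 2)^2. $$
This shows $2 \cdot s \leq \texttt{isqrt}(n) < 2 \cdot s + 2$.  Hence $\texttt{isqrt}(n)$ is either $2 \cdot s$ or $2 \cdot s + 1$, and the test $(2 \cdot s + 1)^2 \leq n$ decides between these two candidates.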

The function `rsqrt` computes the integer square root of the number `n` recursively.

In [None]:
def rsqrt(n):
    if n == 0:
        return 0
    r = 2 * rsqrt(n // 4)
    if (r + 1) ** 2 <= n:
        return r + 1
    else:
        return r

In [None]:
for n in range(10):
    print(f'rsqrt({n}) = {rsqrt(n)}')

In order to test our implementation more thoroughly we use random numbers.

In [None]:
import random
random.seed(0)

The function `run_tests(no_tests, f)` generates `no_tests` random integers `n` with $0 \leq n < 2^{32}$ and checks whether
`f(n)` is the *integer square root* of `n` in each case.

In [None]:
def run_tests(no_tests, f):
    for i in range(no_tests):
        n = random.randrange(2 ** 32)
        r = f(n)
        assert r * r <= n and (r + 1)**2 > n, f'Error: {r} != isqrt({n})'

In [None]:
%%time
run_tests(10**6, rsqrt)
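
As an independent cross-check, the result can also be compared with `math.isqrt` from the standard library, which is available since Python 3.8.  The following is only a sketch: the helper `compare_with_math_isqrt` and its parameter `seed` are hypothetical names that do not occur elsewhere in this notebook, and `rsqrt` is repeated so that the cell runs on its own.

```python
import math
import random

def rsqrt(n):
    """Recursive integer square root, repeated to keep this cell self-contained."""
    if n == 0:
        return 0
    r = 2 * rsqrt(n // 4)
    return r + 1 if (r + 1) ** 2 <= n else r

def compare_with_math_isqrt(no_tests, f, seed=0):
    """Return all tested numbers n for which f(n) differs from math.isqrt(n)."""
    rng = random.Random(seed)  # private generator, the global seed is left untouched
    failures = []
    for _ in range(no_tests):
        n = rng.randrange(2 ** 32)
        if f(n) != math.isqrt(n):
            failures.append(n)
    return failures

# An empty list means that no disagreement was found.
print(compare_with_math_isqrt(10**4, rsqrt))
```

Returning the list of failing inputs instead of raising an assertion makes it easier to inspect several counterexamples at once if an implementation is wrong.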

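The function `sqrt2(k, f)` returns a string that contains $\sqrt{2}$ truncated to `k` decimal places.  The argument `f` is the function that is used to compute the integer square root.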

In [None]:
def sqrt2(k, f):
    n = 2 * 10 ** (2 * k)
    r = f(n)
    s = str(r)
    return s[0] + '.' + s[1:]
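
The scaling by $10^{2 \cdot k}$ works because $\sqrt{2 \cdot 10^{2 \cdot k}} = 10^k \cdot \sqrt{2}$, so the integer square root of $2 \cdot 10^{2 \cdot k}$ consists of the leading digit of $\sqrt{2}$ followed by its first $k$ decimal places.  As a small illustration, the cell below passes `math.isqrt` from the standard library as the argument `f`.  This choice is an assumption made for the example, and `sqrt2` is repeated so that the cell runs on its own.

```python
import math

def sqrt2(k, f):
    """Repeated from above: sqrt(2) truncated to k decimal places, as a string."""
    n = 2 * 10 ** (2 * k)
    s = str(f(n))
    return s[0] + '.' + s[1:]

print(sqrt2(10, math.isqrt))  # 1.4142135623
```

Note that the digits are truncated, not rounded, since the integer square root always rounds down.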

Let us compute $\sqrt{2}$ to $800$ decimal places.

In [None]:
import sys
sys.setrecursionlimit(6000)

In [None]:
%%time
sqrt2(800, rsqrt)

You can compare this with the results shown on https://catonmat.net/tools/generate-sqrt2-digits.

The recursive implementation of `isqrt(n)` is based on the formula
 $$  \texttt{isqrt}(n) = \left\{
     \begin{array}{ll}
       2 \cdot \texttt{isqrt}(n \,\texttt{//}\, 4) + 1 & 
              \mbox{if $\bigl(2 \cdot \texttt{isqrt}(n \,\texttt{//}\, 4) + 1\bigr)^2 \leq n$;} \\
       2 \cdot \texttt{isqrt}(n \,\texttt{//}\, 4)     &
              \mbox{otherwise.}
     \end{array}
     \right.
$$
In each of these two cases, $\texttt{isqrt}(n)$ is computed in terms of $\texttt{isqrt}(n \,\texttt{//}\, 4)$.
The number `n // 4` results from the number `n` by cutting off the last two bits.  If we want to transform our
recursive implementation into an iterative implementation, then the iterative implementation needs to append two
bits of `n` in every iteration, starting with the most significant bits.  Therefore, we first implement an
auxiliary function that splits `n` into these pairs of bits, i.e. into its digits in base 4.
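
The correspondence between division by 4 and the last two bits can be checked directly.  The following cell is a small illustration, where the value `0b110111` is an arbitrary example.

```python
n = 0b110111        # 55 in decimal, an arbitrary example value

print(bin(n // 4))  # 0b1101: the last two bits are cut off
print(bin(n % 4))   # 0b11:   the two bits that have been cut off

# Integer division by 4 is a right shift by two bits,
# and the remainder modulo 4 is a bit mask with 0b11.
print(n // 4 == n >> 2, n % 4 == n & 3)  # True True
```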

The function `list_of_digits` returns a list of digits representing `n` in base 4, i.e.
if `n` is given as
$$ n = \sum\limits_{i=0}^k d_i \cdot 4^i \quad\mbox{where $0 \leq d_i < 4$}$$
then we have
$$ \texttt{list_of_digits}(n) = [d_0, d_1, \cdots, d_k]. $$

In [None]:
def list_of_digits(n):
    L = []
    while n > 0:
        L.append(n % 4)
        n = n // 4
    return L

In [None]:
list_of_digits(18)
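
To see that the digits really represent `n`, we can rebuild the number from the list.  The helper `from_digits` below is a hypothetical name introduced only for this check, and `list_of_digits` is repeated so that the cell runs on its own.

```python
def list_of_digits(n):
    """Base-4 digits of n, least significant digit first (repeated from above)."""
    L = []
    while n > 0:
        L.append(n % 4)
        n //= 4
    return L

def from_digits(L):
    """Rebuild n from its base-4 digits d_0, d_1, ..., d_k."""
    return sum(d * 4 ** i for i, d in enumerate(L))

print(list_of_digits(18))               # [2, 0, 1], since 18 = 2 + 0 * 4 + 1 * 16
print(from_digits(list_of_digits(18)))  # 18
```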

Assume that in base 4 the number $n$ is given as
$$ n = \sum\limits_{j=0}^k d_j \cdot 4^j \quad\mbox{where $0 \leq d_j < 4$.}$$
Let us denote by $m_i$ and $r_i$ the values of the variables `m` and `r` at the
beginning of the $(i+1)^\mathrm{th}$ iteration of the `while`-loop in the function `isr` that is defined below.
Then the following invariants hold:
* $L = [d_0, d_1, \cdots, d_{k-i}]$,
* $m_i = n \;\texttt{//}\; 4^{k+1-i}$,
* $r_i = \texttt{isqrt}(m_i)$.

As the loop ends after $k+1$ iterations, the final value of `r` is 
$$\texttt{isqrt}(m_{k+1}) = \texttt{isqrt}(n \;\texttt{//}\; 4^{k+1-(k+1)}) = \texttt{isqrt}(n).$$

In [None]:
def isr(n):
    L = list_of_digits(n)
    r = 0
    m = 0
    while len(L) > 0:
        m = 4 * m + L.pop()
        if (2 * r + 1) ** 2 <= m:
            r = 2 * r + 1
        else:
            r = 2 * r
    return r
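
The invariants stated above can also be checked at run time.  The following is a sketch under some assumptions: the name `isr_checked` is hypothetical, the digits are computed inline so that the cell runs on its own, and `math.isqrt` from the standard library serves as the reference for $\texttt{isqrt}$.

```python
import math

def isr_checked(n):
    """Variant of isr that asserts the loop invariants in every iteration."""
    L, rest = [], n
    while rest > 0:          # base-4 digits, least significant digit first
        L.append(rest % 4)
        rest //= 4
    k = len(L) - 1           # index of the most significant digit
    r = m = i = 0
    while L:
        assert len(L) == k + 1 - i         # L = [d_0, ..., d_{k-i}]
        assert m == n // 4 ** (k + 1 - i)  # m_i = n // 4^(k+1-i)
        assert r == math.isqrt(m)          # r_i = isqrt(m_i)
        m = 4 * m + L.pop()                # append the next base-4 digit
        r = 2 * r + 1 if (2 * r + 1) ** 2 <= m else 2 * r
        i += 1
    assert m == n and r == math.isqrt(n)   # state after the final iteration
    return r

for n in range(10**4):
    isr_checked(n)           # raises an AssertionError if an invariant fails
```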

In [None]:
for n in range(100):
    print(f'isr({n}) = {isr(n)}')

In [None]:
%%time
run_tests(10**6, isr)

In [None]:
%%time
sqrt2(100, isr)

In [None]:
%%time
sqrt2(10000, isr)