## Babylonian Square Root Algorithm


2019-10-26 - ref. https://www.youtube.com/watch?v=vAp6nUMrKYg

Let's start with a simple example, the computation of $\sqrt{x}$, where the way autodiff works comes as both a mathematical surprise and a computing wonder.  
The example is the Babylonian algorithm, known to mankind for millennia, for computing $\sqrt{x}$:

repeat  
  $t \leftarrow \frac{(t + \frac{x}{t})}{2}$  
until $t$ converges to $\sqrt{x}$

Each iteration has one add and two divides.  
For illustration purposes, 10 iterations suffice.
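The update is Newton's method applied to $f(t) = t^2 - x$. A short derivation, added here as a worked step:

$$t_{\text{new}} = t - \frac{f(t)}{f'(t)} = t - \frac{t^2 - x}{2t} = \frac{t^2 + x}{2t} = \frac{1}{2}\left(t + \frac{x}{t}\right)$$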

In [1]:
function babylonian(x; n = 10)
    t = (1 + x) / 2.          # first step from the initial guess t = 1
    for i = 2:n               # n - 1 further steps, so n steps in total from t = 1
        t = (t + x/t) / 2.
    end
    t
end

babylonian (generic function with 1 method)

Check that it works:

In [2]:
α = π
babylonian(α), √α

(1.7724538509055159, 1.7724538509055159)

In [3]:
for α in 2:3
    println(babylonian(α), " ", √α)
end

1.414213562373095 1.4142135623730951
1.7320508075688772 1.7320508075688772
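The pseudocode above says "until t converges", while `babylonian` simply runs a fixed number of steps. A minimal sketch of a convergence test instead, assuming a relative-change tolerance; `babylonian_tol`, `rtol`, and `maxiter` are hypothetical names, not part of the original notebook:

```julia
# Hypothetical variant: stop when successive iterates agree to within rtol,
# instead of running a fixed number of steps.
function babylonian_tol(x::Real; rtol = 4eps(typeof(float(x))), maxiter = 100)
    x > 0 || throw(DomainError(x, "x must be positive"))
    t = (1 + x) / 2                     # first step from t = 1, as in babylonian
    for _ in 1:maxiter
        t_new = (t + x / t) / 2
        # relative change small enough (allows a one-ulp oscillation)
        abs(t_new - t) <= rtol * t_new && return t_new
        t = t_new
    end
    return t                            # maxiter reached without meeting rtol
end

for α in (2, 3, π, 49)
    println(babylonian_tol(α), " ", √α)
end
```

One side benefit of the fixed step count in `babylonian` is that its loop never compares numbers, so it can run unchanged on the dual numbers defined below, which have no comparison methods; a tolerance test like this one would need them.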


In [4]:
using Plots
plotly()

Plots.PlotlyBackend()

In [5]:
# WARN: the first plot is slow because the plotting package has to load first

ix = 0:.01:49

plot([x -> babylonian(x, n=i) for i = 1:5], 
    ix, 
    label=["Iteration $jx" for _i = 1:1, jx = 1:5])

plot!(sqrt, 
    ix, 
    c="darkslateblue", 
    label="sqrt", 
    title = "Those Babylonians really knew about √")

## And now the derivative, almost by magic

All of this takes only a few lines of code, with no mention of $\frac{1}{2\sqrt{x}}$ anywhere. We will use "dual numbers", denoted $D$ in what follows; they were introduced by the famous algebraist Clifford in 1873.

In [6]:
struct D <: Number # D is a <function, derivative> pair (a pair of floats)
    f::Tuple{Float64, Float64}
end

- Sum Rule: $(x + y)' = x' + y'$
- Quotient Rule: $(\frac{x}{y})' = \frac{yx' - xy'}{y^2}$

In [7]:
import Base: +, /, convert, promote_rule

# overload: 
+(x::D, y::D) = D(x.f .+ y.f)
/(x::D, y::D) = D((x.f[1] / y.f[1], (y.f[1] * x.f[2] - x.f[1] * y.f[2]) / y.f[1]^2))
convert(::Type{D}, x::Real) = D((x, zero(x))) # an ordinary number becomes a dual number with derivative 0
promote_rule(::Type{D}, ::Type{<:Number}) = D # mixed D/Number arithmetic is promoted to D

promote_rule (generic function with 159 methods)

The same algorithm, with no rewrite at all, correctly computes the derivative, as the check shows:

In [8]:
x = 49; babylonian(D((x, 1))), (√x, .5 / √x)

(D((7.0, 0.07142857142857142)), (7.0, 0.07142857142857142))

In [9]:
x = π; babylonian(D((x, 1))), (√x, .5 / √x)

(D((1.7724538509055159, 0.28209479177387814)), (1.7724538509055159, 0.28209479177387814))

## It just works!
How does this work?  
We will explain in a moment; for now, marvel that it does. Note that we did not import any autodiff package: everything is plain Julia.

## The assembler

Most folks don't read assembly, but one can see that it is short (ignoring the comments). The shortness is a clue that suggests speed!

In [10]:
@inline function babylonian(x; n = 10)
    t = (1 + x) / 2.
    for i = 2:n; t = (t + x/t) / 2. end
    t
end

babylonian (generic function with 1 method)

In [11]:
@code_native babylonian(D((2, 1)))

	.text
; ┌ @ In[10]:2 within `babylonian'
; │┌ @ In[10]:2 within `#babylonian#10'
; ││┌ @ promotion.jl:313 within `+' @ In[7]:4
; │││┌ @ broadcast.jl:798 within `materialize'
; ││││┌ @ broadcast.jl:1002 within `copy'
; │││││┌ @ ntuple.jl:42 within `ntuple'
; ││││││┌ @ broadcast.jl:1002 within `#19'
; │││││││┌ @ broadcast.jl:598 within `_broadcast_getindex'
; ││││││││┌ @ broadcast.jl:625 within `_broadcast_getindex_evalf'
; │││││││││┌ @ In[10]:2 within `+'
	vmovsd	(%rsi), %xmm1           # xmm1 = mem[0],zero
	vmovsd	8(%rsi), %xmm2          # xmm2 = mem[0],zero
	movabsq	$140303695308080, %rax  # imm = 0x7F9AFFEB6530
	vaddsd	(%rax), %xmm1, %xmm4
	vxorpd	%xmm8, %xmm8, %xmm8
	vaddsd	%xmm8, %xmm2, %xmm5
	movabsq	$140303695308088, %rax  # imm = 0x7F9AFFEB6538
; │└└└└└└└└└
; │┌ @ float.jl:401 within `#babylonian#10'
	vmovsd	(%rax), %xmm3           # xmm3 = mem[0],zero
	vmulsd	%xmm3, %xmm4, %xmm6
; │└
; │┌ @ In[10]:2 within `#babylonian#10'
; ││┌ @ promotion.jl:316 within `/' @ float.jl:399
	va

## Symbolically
We haven't yet explained how it works, but it may be of some value to see the symbolic expressions that are mathematically equivalent to it, though this is not what the computation actually does.  
Notice below that `babylonian` works on SymPy symbols.  
Note: Python and Julia are very good friends. It's not a competition! Watch how nicely we can use the same code with SymPy.

In [12]:
# using Pkg
# Pkg.add("SymPy")

# using PyCall
using SymPy

In [13]:
x = symbols("x")
display("Iterations as a function of x")
for k = 1:5
    display(simplify(babylonian(x, n=k)))
end

"Iterations as a function of x"

$$0.5x + 0.5$$

$$0.25x + \frac{1.0x}{x + 1} + 0.25$$

$$\frac{1.0\left(0.015625x^{4} + 0.4375x^{3} + 1.09375x^{2} + 0.4375x + 0.015625\right)}{0.125x^{3} + 0.875x^{2} + 0.875x + 0.125}$$

$$\frac{1.0\left(6.103515625\times10^{-5}x^{8} + 0.00732421875x^{7} + 0.111083984375x^{6} + 0.48876953125x^{5} + 0.7855224609375x^{4} + 0.48876953125x^{3} + 0.111083984375x^{2} + 0.00732421875x + 6.103515625\times10^{-5}\right)}{0.0009765625x^{7} + 0.0341796875x^{6} + 0.2666015625x^{5} + 0.6982421875x^{4} + 0.6982421875x^{3} + 0.2666015625x^{2} + 0.0341796875x + 0.0009765625}$$

$$\frac{1.0\left(9.31322574615479\times10^{-10}x^{16} + 4.61935997009277\times10^{-7}x^{15} + 3.34903597831726\times10^{-5}x^{14} + 0.00084395706653595x^{13} + 0.00979593023657799x^{12} + 0.0600817054510117x^{11} + 0.210285969078541x^{10} + 0.439058616757393x^{9} + 0.559799736365676x^{8} + \cdots\right)}{2.98023223876953\times10^{-8}x^{15} + 4.61935997009277\times10^{-6}x^{14} + 0.000187546014785767x^{13} + 0.00313469767570496x^{12} + 0.0261224806308746x^{11} + \cdots}$$

(output truncated)
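The coefficients in these outputs are scaled binomial coefficients; for example, the $k = 4$ numerator coefficients are $\binom{16}{2j}/2^{14}$ ($1/16384$, $120/16384$, $1820/16384$, and so on). One way to see why, added here as a side derivation: since $t_{\text{new}} \mp \sqrt{x} = \frac{(t \mp \sqrt{x})^2}{2t}$, each step squares the ratio $r = \frac{t - \sqrt{x}}{t + \sqrt{x}}$, so after $k$ steps from $t = 1$

$$t_k = \sqrt{x}\,\frac{(1 + \sqrt{x})^{2^k} + (1 - \sqrt{x})^{2^k}}{(1 + \sqrt{x})^{2^k} - (1 - \sqrt{x})^{2^k}}$$

A quick numerical check of that closed form; `babylonian_k` and `babylonian_closed` are hypothetical helper names:

```julia
# Mirrors the notebook's babylonian: the first line is one step from t = 1,
# so n counts the total number of steps taken from t = 1.
function babylonian_k(x, n)
    t = (1 + x) / 2
    for _ in 2:n
        t = (t + x / t) / 2
    end
    t
end

# Closed form: the ratio (t - √x)/(t + √x) is squared at every step.
function babylonian_closed(x, n)
    s = sqrt(x)
    p = (1 + s)^(2^n)
    m = (1 - s)^(2^n)
    s * (p + m) / (p - m)
end

for x in (0.5, 2.0, 10.0), n in 1:4
    println((x, n), "  ", babylonian_k(x, n), "  ", babylonian_closed(x, n))
end
```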

In [14]:
display("Derivatives as a function of x")
for k = 1:5
    display(simplify(diff(simplify(babylonian(x, n=k)), x)))
end

"Derivatives as a function of x"

$$0.500000000000000$$

$$0.25 + \frac{1.0}{(x + 1)^{2}}$$

$$\frac{1.0\left(0.001953125x^{6} + 0.02734375x^{5} + 0.287109375x^{4} + 0.6640625x^{3} + 0.732421875x^{2} + 0.24609375x + 0.041015625\right)}{0.015625x^{6} + 0.21875x^{5} + 0.984375x^{4} + 1.5625x^{3} + 0.984375x^{2} + 0.21875x + 0.015625}$$

$$\frac{1.0\left(5.96046447753906\times10^{-8}x^{14} + 4.17232513427734\times10^{-6}x^{13} + 0.000190675258636475x^{12} + 0.00312113761901855x^{11} + 0.0261631608009338x^{10} + 0.120073914527893x^{9} + 0.323666036128998x^{8} + 0.526678562164307x^{7} + 0.527062118053436x^{6} + \cdots\right)}{9.5367431640625\times10^{-7}x^{14} + 6.67572021484375\times10^{-5}x^{13} + 0.00168895721435547x^{12} + 0.0195884704589844x^{11} + 0.120171546936035x^{10} + 0.420557022094727x^{9} + \cdots}$$

(output truncated)

$$\frac{1.0\left(2.77555756156289\times10^{-17}x^{30} + 8.60422844084496\times10^{-15}x^{29} + 1.65975566623899\times10^{-12}x^{28} + 1.34642408333718\times10^{-10}x^{27} + 5.97234897647958\times10^{-9}x^{26} + 1.61239274432123\times10^{-7}x^{25} + 2.84861239072121\times10^{-6}x^{24} + 3.45901327480913\times10^{-5}x^{23} + \cdots\right)}{8.88178419700125\times10^{-16}x^{30} + 2.75335310107039\times10^{-13}x^{29} + 3.25171001236413\times10^{-11}x^{28} + 1.91952764794223\times10^{-9}x^{27} + 6.5691122763667\times10^{-8}x^{26} + \cdots}$$

(output truncated)

Let's take the derivative of the Babylonian iteration with respect to $x$ by hand, specifically $t' = \frac{dt}{dx}$.  
This is the old-fashioned way: a human writes the derivative code.  

In [15]:
function dbabylonian(x; n = 10)
    t = (1 + x) / 2.
    dt = 1 / 2.
    
    for i = 2:n 
        dt = (dt + (t - x * dt) / t^2) / 2.   # uses t from before this step's update
        t = (t + x/t) / 2.
    end
    dt
end

dbabylonian (generic function with 1 method)

Note: the update is  
$t_{\text{new}} = \frac{1}{2} \times (t + \frac{x}{t})$  

so, by the sum and quotient rules, $t'_{\text{new}} = \frac{1}{2} \times (t' + \frac{x' \times t - x \times t'}{t^2})$, and as $x' = 1$, we get:  
$t'_{\text{new}} = \frac{1}{2} \times (t' + \frac{t - x \times t'}{t^2})$

This rewritten code gets the right answer, as the check shows.  
So the trick is to have the computer system do this rewriting for us, without any loss of speed or convenience.

In [16]:
x= π; dbabylonian(x), .5 / √x

(0.28209479177387814, 0.28209479177387814)

What just happened?  
Answer: we created, by hand, an iteration for $t'$ from our iteration for $t$, then ran it alongside the iteration for $t$.
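As a sketch of "running the iteration alongside", here both updates are done in one tuple assignment, so the derivative step uses the value of $t$ from before the update, as the chain rule requires. `dbabylonian_sketch` is a hypothetical name. For $n = 2$ it should reproduce the SymPy derivative $0.25 + \frac{1}{(x + 1)^2}$ shown earlier, and for larger $n$ it should approach $\frac{1}{2\sqrt{x}}$:

```julia
# Hand-derived derivative iteration, run alongside the iteration for t.
function dbabylonian_sketch(x; n = 10)
    t, dt = (1 + x) / 2, 1 / 2          # t₁ = (1 + x)/2, so t₁' = 1/2
    for _ in 2:n
        # both right-hand sides are evaluated with the previous (t, dt) pair
        t, dt = (t + x / t) / 2, (dt + (t - x * dt) / t^2) / 2
    end
    dt
end

# n = 2: compare with the symbolic derivative 0.25 + 1/(x + 1)^2
for x in (0.5, 2.0, 9.0)
    println(dbabylonian_sketch(x; n = 2), "  ", 0.25 + 1 / (x + 1)^2)
end

# n = 10: compare with 1/(2√x)
println(dbabylonian_sketch(π), "  ", 0.5 / sqrt(π))
```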

In [17]:
babylonian(D((x, 1)))

D((1.7724538509055159, 0.28209479177387814))

How did this work?  
Through multiple dispatch, it created the same derivative iteration that we wrote by hand, using very general rules that are defined once and never need to be rewritten by hand.  

**Important:** the derivative rules are substituted before JIT compilation, so efficient compiled code is executed.

## Dual Number Notation

Instead of $D(a, b)$ we can write $a + b \epsilon$, where $\epsilon$ satisfies $\epsilon^2 = 0$. Some people like to recall imaginary numbers, where an $i$ is introduced with $i^2 = -1$. Others like to think of how engineers just drop the $O(\epsilon^2)$ terms.  
The four rules are:  
- 1 & 2. $(a + b\epsilon) \pm (c + d\epsilon) = (a \pm c) + (b \pm d)\epsilon$
- 3. $(a + b\epsilon) \times (c + d\epsilon) = ac + (bc + ad)\epsilon$
- 4. $\frac{a + b\epsilon}{c + d\epsilon} = \frac{a}{c} + \frac{bc - ad}{c^2}\epsilon$
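Rule 4 can be checked by multiplying the numerator and denominator by $c - d\epsilon$ (assuming $c \neq 0$) and using $\epsilon^2 = 0$:

$$\frac{a + b\epsilon}{c + d\epsilon} = \frac{(a + b\epsilon)(c - d\epsilon)}{(c + d\epsilon)(c - d\epsilon)} = \frac{ac + (bc - ad)\epsilon - bd\,\epsilon^2}{c^2 - d^2\epsilon^2} = \frac{a}{c} + \frac{bc - ad}{c^2}\,\epsilon$$

This matches the quotient rule with $x = a$, $x' = b$, $y = c$, $y' = d$. Applying the four rules repeatedly to any rational expression gives $f(a + \epsilon) = f(a) + f'(a)\,\epsilon$, which is why the dual-number type carries the derivative along.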


In [18]:
Base.show(io::IO, x::D) = print(io, x.f[1], " + ", x.f[2], " ϵ")

In [19]:
# Add the two remaining rules (subtraction and multiplication) to the type system

import Base: -, *

# overload: 
-(x::D, y::D) = D(x.f .- y.f)
*(x::D, y::D) = D((x.f[1] * y.f[1], (y.f[1] * x.f[2] + x.f[1] * y.f[2])))

* (generic function with 402 methods)

In [20]:
D((1, 0))

1.0 + 0.0 ϵ

In [21]:
D((0, 2))^2 # should be zero!

0.0 + 0.0 ϵ

In [22]:
D((2, 1))^2

4.0 + 4.0 ϵ

In [23]:
ϵ = D((0, 1))
@code_native(ϵ^2)

	.text
; ┌ @ intfuncs.jl:222 within `^'
	pushq	%rbx
	subq	$16, %rsp
	movq	%rdi, %rbx
	movabsq	$power_by_squaring, %rax
	movq	%rsp, %rdi
	callq	*%rax
	vmovups	(%rsp), %xmm0
	vmovups	%xmm0, (%rbx)
	movq	%rbx, %rax
	addq	$16, %rsp
	popq	%rbx
	retq
	nopl	(%rax)
; └


In [24]:
ϵ * ϵ

0.0 + 0.0 ϵ

In [25]:
ϵ^2

0.0 + 0.0 ϵ

In [26]:
1 / (1 + ϵ) # the power series 1 - ϵ + ϵ^2 - ϵ^3 + ... truncates exactly to 1 - ϵ

1.0 + -1.0 ϵ

In [27]:
(1 + ϵ)^5 # it just works: we never defined ^ for D; Julia's generic power_by_squaring uses our *

1.0 + 5.0 ϵ
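A minimal self-contained sketch collecting all four rules in one type, with separate fields instead of a tuple. `DualNum`, `ϵd`, `f`, and `fprime` are hypothetical names chosen so the block does not clash with `D` and `ϵ` above; it checks the dual result for a rational function against its hand-derived derivative:

```julia
import Base: +, -, *, /, convert, promote_rule, show

struct DualNum <: Number
    a::Float64   # real part (function value)
    b::Float64   # ϵ part (derivative)
end

+(x::DualNum, y::DualNum) = DualNum(x.a + y.a, x.b + y.b)                       # rule 1
-(x::DualNum, y::DualNum) = DualNum(x.a - y.a, x.b - y.b)                       # rule 2
*(x::DualNum, y::DualNum) = DualNum(x.a * y.a, x.b * y.a + x.a * y.b)           # rule 3
/(x::DualNum, y::DualNum) = DualNum(x.a / y.a, (x.b * y.a - x.a * y.b) / y.a^2) # rule 4

convert(::Type{DualNum}, x::Real) = DualNum(x, 0.0)      # constants get a zero ϵ part
promote_rule(::Type{DualNum}, ::Type{<:Real}) = DualNum  # mix with ordinary numbers
show(io::IO, x::DualNum) = print(io, x.a, " + ", x.b, " ϵ")

ϵd = DualNum(0.0, 1.0)

# A rational function and its derivative written out by hand (quotient rule)
f(x) = (x^3 - 2x) / (x + 1)
fprime(x) = ((3x^2 - 2) * (x + 1) - (x^3 - 2x)) / (x + 1)^2

y = f(2.0 + ϵd)   # value and derivative at x = 2 in a single pass
println(y, "    vs    ", f(2.0), " + ", fprime(2.0), " ϵ")
println((1 + ϵd)^5, "    ", 1 / (1 + ϵd))
```

The `^` calls work without a dedicated method because Julia's generic integer power falls back to repeated multiplication, consistent with the call to `power_by_squaring` in the `@code_native` output above.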

## Generalization to arbitrary roots

In [28]:
function nth_root(x, n=2; t=1, p=10)
    for i = 1:p
        t += (x / t^(n-1) - t) / n
    end
    t
end

nth_root (generic function with 2 methods)
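The update inside `nth_root` is Newton's method for $t^n = x$ (a worked step added for clarity); for $n = 2$ it reduces to the Babylonian update:

$$t_{\text{new}} = t - \frac{t^n - x}{n\,t^{n-1}} = t + \frac{1}{n}\left(\frac{x}{t^{n-1}} - t\right), \qquad n = 2:\quad t_{\text{new}} = \frac{1}{2}\left(t + \frac{x}{t}\right)$$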

In [29]:
nth_root(2, 3), ∛2  # copied from https://www.alt-codes.net/root-symbols

(1.2599210498948732, 1.2599210498948732)

In [30]:
nth_root(2 + ϵ, 3)

1.2599210498948732 + 0.20998684164914552 ϵ

In [31]:
nth_root(7, 12), 7^(1/12)

(1.1760474285795146, 1.1760474285795146)

In [32]:
x = 2.0
nth_root(x + ϵ), √x, .5 / √x  # default n = 2: compare with √x and d√x/dx = 1/(2√x)

(1.414213562373095 + 0.35355339059327373 ϵ, 1.4142135623730951, 0.35355339059327373)

## Forward Diff

Now that you understand it, you can use the official package.


In [33]:
# Pkg.add("ForwardDiff")
using ForwardDiff

In [34]:
ForwardDiff.derivative(sqrt, 2)

0.35355339059327373

In [35]:
ForwardDiff.derivative(√, 2)

0.35355339059327373

In [36]:
ForwardDiff.derivative(babylonian, 2)

0.35355339059327373

In [37]:
@which ForwardDiff.derivative(√, 2)

## A Close Look at Convergence with BigFloats

The $-\log_{10}$ of the error gives the number of correct digits. Watch the quadratic convergence right before your eyes: the number of correct digits roughly doubles with every iteration.
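The doubling comes from the error recurrence (a short derivation added here):

$$t_{\text{new}} - \sqrt{x} = \frac{t^2 + x}{2t} - \sqrt{x} = \frac{(t - \sqrt{x})^2}{2t}$$

so each error is roughly the square of the previous one divided by $2\sqrt{x}$. For $t > 0$ the right-hand side is nonnegative, so every iterate after the starting guess is at least $\sqrt{x}$; that is why the code can take `log10` of the difference without `abs`.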

In [79]:
setprecision(3000)

# round.(Float64.(log10.([babylonian(BigFloat(2), n=k) for k=1:10] - √BigFloat(2))), 3)
round.(Float64.(log10.([babylonian(BigFloat(2), n=k) - √BigFloat(2.) for k=1:10])), sigdigits=6)

10-element Array{Float64,1}:
   -1.06658
   -2.61028
   -5.67287
  -11.7973 
  -24.0461 
  -48.5437 
  -97.539  
 -195.53   
 -391.511  
 -783.473  

In [80]:
struct D1{T} <: Number # D1 is a <function, derivative> pair with element type T
    f::Tuple{T, T}
end

In [81]:
z = D((2., 1.))
z1 = D1((BigFloat(2.), BigFloat(1.)))

D1{BigFloat}((2.0, 1.0))

In [82]:
import Base: +, /, convert, promote_rule

# overload: 
+(x::D1, y::D1) = D1(x.f .+ y.f)
/(x::D1, y::D1) = D1((x.f[1] / y.f[1], (y.f[1] * x.f[2] - x.f[1] * y.f[2]) / y.f[1]^2))
convert(::Type{D1{T}}, x::Real) where {T} = D1((convert(T, x), zero(T))) 
promote_rule(::Type{D1{T}}, ::Type{S}) where {T, S <: Number} = D1{promote_type(T, S)} 

promote_rule (generic function with 172 methods)
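The notebook defines `D1` but does not show it in action. A self-contained sketch of how such a parametric dual type could be used with `BigFloat`, assuming the same `+`, `/`, `convert`, and `promote_rule` pattern as `D1`; `BDual`, `babylonian_b`, `digits_val`, and `digits_der` are hypothetical names. It reports $-\log_{10}$ of the error in both the value and the derivative after 10 steps at 3000 bits:

```julia
import Base: +, /, convert, promote_rule

struct BDual{T<:Real} <: Number
    val::T   # function value
    der::T   # derivative
end

+(x::BDual, y::BDual) = BDual(x.val + y.val, x.der + y.der)
/(x::BDual, y::BDual) = BDual(x.val / y.val, (y.val * x.der - x.val * y.der) / y.val^2)
# ordinary numbers become duals with zero derivative, in element type T
convert(::Type{BDual{T}}, x::Real) where {T} = BDual{T}(convert(T, x), zero(T))
promote_rule(::Type{BDual{T}}, ::Type{S}) where {T, S<:Real} = BDual{promote_type(T, S)}

# the same fixed-step iteration as the notebook's babylonian
function babylonian_b(x; n = 10)
    t = (1 + x) / 2
    for _ in 2:n
        t = (t + x / t) / 2
    end
    t
end

digits_val, digits_der = setprecision(BigFloat, 3000) do
    two = BigFloat(2)
    r = babylonian_b(BDual(two, one(two)))   # seed dx/dx = 1
    # -log10|error| is roughly the number of correct decimal digits
    (Float64(-log10(abs(r.val - sqrt(two)))),
     Float64(-log10(abs(r.der - 1 / (2 * sqrt(two))))))
end
println(digits_val, "  ", digits_der)
```

Both numbers should come out at several hundred correct digits, in line with the convergence table above.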

In [83]:
A = randn(3, 3)

3×3 Array{Float64,2}:
 -0.207265   1.51992   -0.917178
  0.527821  -0.65434   -0.299627
  1.02614    0.840293  -1.70233 

In [84]:
x = randn(3)

3-element Array{Float64,1}:
  0.7878839266680854
 -0.3854946973777788
 -0.8610926766776773

In [85]:
ForwardDiff.gradient(x -> x'A*x, x)

3-element Array{Float64,1}:
 -1.2098230999591228
  1.652310614650758 
  2.8091580552435125

In [86]:
(A + A') * x

3-element Array{Float64,1}:
 -1.2098230999591226
  1.652310614650758 
  2.8091580552435125
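`ForwardDiff.gradient` agrees with the analytic gradient $(A + A^T)x$ of $x^T A x$. A sketch of one way to obtain such a gradient with plain dual numbers, by seeding one coordinate at a time (one forward pass per input). `GDual`, `quadform`, `forward_gradient`, `M`, and `v` are hypothetical names, and this is not a description of how ForwardDiff works internally:

```julia
import Base: +, *, convert, promote_rule

struct GDual <: Number
    val::Float64   # value
    der::Float64   # derivative along the seeded coordinate
end

+(x::GDual, y::GDual) = GDual(x.val + y.val, x.der + y.der)
*(x::GDual, y::GDual) = GDual(x.val * y.val, x.der * y.val + x.val * y.der)
convert(::Type{GDual}, x::Real) = GDual(x, 0.0)
promote_rule(::Type{GDual}, ::Type{<:Real}) = GDual

# xᵀ A x with explicit loops, so only + and * are needed
function quadform(A, x)
    s = zero(promote_type(eltype(A), eltype(x)))
    for i in axes(A, 1), j in axes(A, 2)
        s += x[i] * A[i, j] * x[j]
    end
    s
end

# Seed coordinate k with derivative 1 (all others 0); the der field of the
# result is then ∂f/∂x_k. One forward pass per coordinate.
function forward_gradient(f, x::AbstractVector{<:Real})
    map(eachindex(x)) do k
        xd = [GDual(x[i], i == k ? 1.0 : 0.0) for i in eachindex(x)]
        f(xd).der
    end
end

M = [1.0 2.0 0.0; -1.0 3.0 4.0; 0.5 0.0 -2.0]
v = [0.3, -1.2, 2.0]
println(forward_gradient(w -> quadform(M, w), v))
println((M + M') * v)
```

Seeding one coordinate per pass costs $n$ evaluations for $n$ inputs; carrying several partial derivatives per number is one common way to reduce that.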

In [90]:
using LinearAlgebra

n = 4
Strang = SymTridiagonal(2 * ones(n), - ones(n-1))

4×4 SymTridiagonal{Float64,Array{Float64,1}}:
  2.0  -1.0    ⋅     ⋅ 
 -1.0   2.0  -1.0    ⋅ 
   ⋅   -1.0   2.0  -1.0
   ⋅     ⋅   -1.0   2.0