# Shor's Algorithm

Shor's algorithm finds the factors of a large composite number in polynomial time, with major implications for cryptographic security. The crux of the "quantum" part of the algorithm is finding the period of a modular exponentiation function, which we explore below.
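Everything except the period-finding subroutine is classical. The sketch below shows that reduction in plain Python, assuming `N` is an odd composite that is not a prime power. The quantum step is replaced by a brute-force order search, so this version is exponential in the bitsize. `order_brute_force` and `shor_classical_sketch` are hypothetical helpers for illustration and are not part of `cirq_qubitization`.

```python
import math
import random


def order_brute_force(g: int, N: int) -> int:
    """Smallest r >= 1 with g**r % N == 1 (assumes gcd(g, N) == 1).

    Stand-in for quantum period finding; takes up to ~N steps.
    """
    r, value = 1, g % N
    while value != 1:
        value = (value * g) % N
        r += 1
    return r


def shor_classical_sketch(N: int, seed: int = 0) -> tuple[int, int]:
    """Factor an odd composite N (assumed not a prime power) via order finding."""
    rng = random.Random(seed)
    while True:
        g = rng.randrange(2, N - 1)      # random base in [2, N - 2]
        d = math.gcd(g, N)
        if d > 1:
            return d, N // d             # lucky guess: g already shares a factor
        r = order_brute_force(g, N)      # the "quantum" part
        if r % 2 == 1:
            continue                     # odd period: try another base
        y = pow(g, r // 2, N)
        if y == N - 1:
            continue                     # g**(r/2) = -1 mod N: try another base
        p = math.gcd(y - 1, N)           # nontrivial because y != +-1 mod N
        return p, N // p


print(shor_classical_sketch(13 * 17))
```

The random base means the two factors may come back in either order.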

In [None]:
import numpy as np
import sympy

from cirq_qubitization.jupyter_tools import show_bloq
from cirq_qubitization.bloq_algos.shors import *
from cirq_qubitization.bloq_algos.shors.shors import *

## Example: Factor 13 * 17

For sufficiently small problem sizes, we can explore the mechanism of the algorithm classically. Here, we will use period finding to factor `N = 13 * 17`.

 - `N` is the composite number to factor
 - `n` is its bitsize
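The cell below computes `n` with a floating-point logarithm. Python integers also expose their bitsize directly through `int.bit_length()`, which avoids floating point entirely. A small comparison in plain Python:

```python
import math

N = 13 * 17

# Float-based bitsize, as in this notebook (which uses numpy for the same formula).
n_float = math.ceil(math.log2(N))

# Integer-only alternative: number of bits in N's binary representation.
n_int = N.bit_length()

print(n_float, n_int)  # 8 8
```

The two agree whenever `N` is not a power of two, which covers every odd `N > 1`. For `N = 2**k` the log-based formula gives `k`, although `N` needs `k + 1` bits.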

In [None]:
N = 13*17
n = int(np.ceil(np.log2(N)))
N, n

The first step is to pick a random guess to serve as the base of our exponentiation (for which we will find the period).
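If the guess happens to share a factor with `N`, no period finding is needed, because `gcd(g, N)` is already a factor. A quick check with `math.gcd` for the base `g = 8` used in this example and two deliberately "lucky" bases (`check_base` is a hypothetical helper):

```python
import math

N = 13 * 17


def check_base(g: int, N: int) -> str:
    """Decide what to do with a candidate base g (hypothetical helper)."""
    d = math.gcd(g, N)
    if d == 1:
        return "coprime: proceed to period finding"
    return f"lucky: {d} divides N"


for g in (8, 26, 34):
    print(g, check_base(g, N))
```

Here `26 = 2 * 13` and `34 = 2 * 17` reveal a factor immediately, while `8` is coprime to 221 and needs the period.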

In [None]:
g = 8
g

### Modular exponentiation

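The function $f(e) = g^e \bmod N$ is periodic in $e$. It takes at most $N$ distinct values, and provided $g$ shares no factor with $N$ (so that $g$ is invertible modulo $N$), the sequence $f(0) = 1, f(1), f(2), \dots$ eventually returns to $1$ and then repeats. We can evaluate successive values of $e$ to see how long it takes to cycle back to $1$.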
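The search for the first return to $1$ can be written as a small function. It uses Python's three-argument `pow`, which reduces modulo `N` at every step instead of building the huge integer `g ** e` first. `find_period` is a hypothetical name, and the loop is the same brute force as the cell below, so it takes up to about `N` steps.

```python
import math


def find_period(g: int, N: int) -> int:
    """Smallest r >= 1 with g**r congruent to 1 modulo N, by brute force.

    pow(g, e, N) keeps intermediate values below N**2 instead of letting
    them grow like g**e.
    """
    if math.gcd(g, N) != 1:
        raise ValueError("g must be coprime to N for a period to exist")
    e = 1
    while pow(g, e, N) != 1:
        e += 1
    return e


print(find_period(8, 13 * 17))  # 8
print(find_period(8, 3 * 5))    # 4
```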

In [None]:
for e in range(20):
    f = (g ** e) % N
    
    star = ' *' if f == 1 else ''
    print(f'{e:5d} {f:5d}{star}')

The code above prints an asterisk whenever the function evaluates to 1. The asterisks are evenly spaced, so the function has a consistent period:

In [None]:
16-8, 8-0

In [None]:
period = 8
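Reading the period off a printed table is easy to get wrong. A programmatic check in plain Python confirms that 8 is a period of $f$ and also the smallest one:

```python
N = 13 * 17
g = 8
period = 8

# Shifting e by the period never changes f(e) = g**e mod N ...
assert all(pow(g, e + period, N) == pow(g, e, N) for e in range(100))
# ... and no smaller positive exponent brings f back to 1,
# so `period` is the order of g modulo N.
assert all(pow(g, s, N) != 1 for s in range(1, period))
print(f"period {period} confirmed")
```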

### Use the period to find factors

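For large numbers, finding the period classically is expensive: stepping through exponents one at a time, as above, can take on the order of $N$ evaluations of the exponentiation function, which is exponential in the bitsize $n$, and no efficient classical method is known. A quantum computer can find the period using a number of operations polynomial in $n$; see the references.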

Once we have the period $r$, a few number-theoretic steps (one modular exponentiation and two greatest-common-divisor computations) recover the two factors efficiently on a classical computer.
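The identity behind these steps: with $r$ the period, $g^r \equiv 1 \pmod N$, so

$$
\left(g^{r/2} - 1\right)\left(g^{r/2} + 1\right) = g^{r} - 1 \equiv 0 \pmod N .
$$

If $r$ is even and $g^{r/2} \not\equiv -1 \pmod N$ (the case $g^{r/2} \equiv 1$ is already excluded because $r$ is the smallest period), then $N$ divides the product but neither factor, so $\gcd\left(g^{r/2} \pm 1, N\right)$ are both nontrivial factors of $N$. For $g = 8$, $r = 8$ and $N = 221$:

$$
8^{4} = 4096 \equiv 118 \pmod{221}, \qquad \gcd(119, 221) = 17, \qquad \gcd(117, 221) = 13 .
$$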

In [None]:
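assert period % 2 == 0

# g^(r/2) must not be congruent to -1 modulo N. Compare residues:
# the raw integer g**(period//2) is positive and can never equal -1.
half_power = pow(g, period // 2, N)
assert half_power != N - 1

p1 = half_power + 1
m1 = half_power - 1

assert (p1 * m1) % N == 0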

In [None]:
import math

print(f'gcd{p1%N, N}, gcd{m1%N, N}')
math.gcd(p1%N, N), math.gcd(m1%N, N)
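The two assertions in the factoring cell guard against bases for which the trick fails: an odd period, or $g^{r/2} \equiv -1 \pmod N$. A short survey of every base for $N = 221$ shows how often each case occurs. It uses a brute-force period and a hypothetical helper `classify_base`:

```python
import math


def classify_base(g: int, N: int) -> str:
    """Report whether base g leads to factors of N through the gcd step."""
    if math.gcd(g, N) != 1:
        return "shares a factor"      # gcd(g, N) is already a factor
    r = 1                             # brute-force period (stand-in for the quantum step)
    while pow(g, r, N) != 1:
        r += 1
    if r % 2 == 1:
        return "odd period"           # r/2 is not an integer
    if pow(g, r // 2, N) == N - 1:
        return "g^(r/2) = -1 mod N"   # both gcds would be trivial
    return "good"


N = 13 * 17
counts = {}
for g in range(2, N - 1):
    label = classify_base(g, N)
    counts[label] = counts.get(label, 0) + 1
print(counts)
```

In the full algorithm a bad base simply means drawing a new random $g$ and repeating.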

## Bloqs

We'll perform modular exponentiation as above with `g` and `N`. `exponent` is a $2n$-bit input register, and the bloq allocates a new $n$-bit register `x` to hold the output $g^{e} \bmod N$.
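Why $2n$ bits for the exponent? In Shor's algorithm, measuring the exponent register after the quantum Fourier transform gives an integer $y$ close to $j \cdot 2^m / r$ for some unknown $j$, where $m$ is the register size. With $m = 2n$ we have $2^m > N^2$, so the rounding error in $y / 2^m$ is small enough that $j/r$ (in lowest terms) is the unique closest fraction with denominator at most $N$. The classical post-processing can then recover it with continued fractions, which `fractions.Fraction.limit_denominator` implements. The sketch below assumes an idealized, noiseless outcome (`ideal_outcome` and `period_candidate` are hypothetical helpers). It uses base `g = 2`, whose period 24 does not divide $2^m$, so the approximation step is not trivial.

```python
import math
from fractions import Fraction

N = 13 * 17
n = N.bit_length()   # 8 bits for N = 221
m = 2 * n            # exponent register size: 2n = 16 bits
g = 2                # base chosen so that the period does not divide 2**m

# Brute-force period, standing in for the quantum part.
r = 1
while pow(g, r, N) != 1:
    r += 1


def ideal_outcome(j: int) -> int:
    """Hypothetical noiseless measurement: j * 2**m / r rounded to an integer."""
    return round(Fraction(j * 2**m, r))


def period_candidate(y: int) -> int:
    """Denominator of the closest fraction to y / 2**m with denominator <= N."""
    return Fraction(y, 2**m).limit_denominator(N).denominator


for j in range(1, 6):
    y = ideal_outcome(j)
    q = period_candidate(y)
    print(f"j={j}  y={y}  candidate r={q}  g**q % N == 1: {pow(g, q, N) == 1}")
```

When $\gcd(j, r) > 1$ the recovered denominator is only a proper divisor of $r$, which the final `pow` check rejects; in that case one repeats the measurement.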

In [None]:
bloq = ModExp.make_for_shor(big_n=N, g=g)
show_bloq(bloq)

We can simulate individual inputs for `exponent` to check the classical logic.

In [None]:
for e in range(20):
    f_ref = (g ** e) % N
    _, f_bloq = bloq.call_classically(exponent=e)
    assert f_ref == f_bloq

print("Checks out!")

In [None]:
from sympy import Symbol
show_bloq(
    ModExp(base=Symbol('g'), mod=Symbol('N'), exp_bitsize=2*Symbol('n'), x_bitsize=Symbol('n'))
)

### Smaller example

In [None]:
N = 3*5
n = int(np.ceil(np.log2(N)))
g = 8

for e in range(20):
    f = (g ** e) % N
    star = ' *' if f == 1 else ''
    print(f'{e:5d} {f:5d}{star}')
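For $N = 15$ and $g = 8$ the period is $4$ ($8, 4, 2, 1, \dots$), which is even, and $8^2 \equiv 4 \not\equiv -1 \pmod{15}$, so the same gcd step used for 221 applies. In plain Python:

```python
import math

N, g = 3 * 5, 8

# Brute-force period: 8, 4, 2, 1, ... so r = 4.
r = 1
while pow(g, r, N) != 1:
    r += 1

half_power = pow(g, r // 2, N)   # 8**2 = 64, which is 4 modulo 15
factors = (math.gcd(half_power + 1, N), math.gcd(half_power - 1, N))
print(r, half_power, factors)    # 4 4 (5, 3)
```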

In [None]:
bloq = ModExp.make_for_shor(big_n=N, g=g)
show_bloq(bloq)

In [None]:
import inspect
import textwrap
from IPython.display import Code

In [None]:
source = inspect.getsource(bloq.build_composite_bloq)
source = textwrap.dedent(source)
Code(source, language='python3')
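A common way to build modular exponentiation from smaller pieces uses the binary expansion of the exponent, $g^e = \prod_j \left(g^{2^j}\right)^{e_j}$: the output register starts at $1$ and is multiplied by the classically precomputed constant $g^{2^j} \bmod N$ whenever bit $e_j$ is set. The pure-Python model below, built around the hypothetical helper `mod_exp_by_controlled_mults`, is a sketch of that textbook construction. It is not derived from the source printed above and may differ from it in details such as bit ordering.

```python
def mod_exp_by_controlled_mults(exponent: int, g: int, N: int, exp_bitsize: int) -> int:
    """Classical model of g**exponent mod N from controlled modular multiplications.

    Textbook construction, assumed here for illustration: x starts at 1 and, for
    each bit j of `exponent` (least significant first), x is multiplied by the
    precomputed constant k_j = g**(2**j) mod N when that bit is 1.
    """
    x = 1
    k_j = g % N                      # k_0 = g mod N
    for j in range(exp_bitsize):
        if (exponent >> j) & 1:      # bit j acts as the control
            x = (x * k_j) % N
        k_j = (k_j * k_j) % N        # k_{j+1} = k_j**2 mod N (classical squaring)
    return x


N, g = 3 * 5, 8
n = N.bit_length()                   # 4 bits, so the exponent register has 2n = 8
for e in range(2 ** (2 * n)):
    assert mod_exp_by_controlled_mults(e, g, N, 2 * n) == pow(g, e, N)
print("matches pow(g, e, N) for every 8-bit exponent")
```

For $g = 8$ and $N = 15$ the first constant is $k_0 = 8$, and the later ones are $8^2 \equiv 4$, $4^2 \equiv 1$, and then $1$ for every higher bit.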

In [None]:
for ei in range(20):
    e, f = bloq.call_classically(exponent=ei)
    star = ' *' if f == 1 else ''
    print(f'{e:5d} {f:5d}{star}')

In [None]:
from cirq_qubitization.quantum_graph.graphviz import ClassicalSimGraphDrawer
ClassicalSimGraphDrawer(bloq, {'exponent': 5}).get_svg()

In [None]:
cbloq = bloq.decompose_bloq()
show_bloq(cbloq)

In [None]:
ClassicalSimGraphDrawer(cbloq, {'exponent': 5}).get_svg()

In [None]:
cbloq.call_classically(exponent=5)

In [None]:
for ei in range(20):
    e, f = bloq.call_classically(exponent=ei)
    e2, f2 = cbloq.call_classically(exponent=ei)
    # The decomposed bloq should reproduce the original bloq's output.
    assert f == f2
    star = ' *' if f == 1 else ''
    print(f'{e:5d} {f:5d}{star:2s}  {e2:5d} {f2:5d}')

In [None]:
cmm = CtrlModMul(k=8, bitsize=n, mod=N)
print(cmm)
show_bloq(cmm.decompose_bloq())
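Judging by its name and parameters, `CtrlModMul(k=8, bitsize=n, mod=N)` multiplies the register `x` by the constant `k` modulo `mod` when its control bit is set. That reading is an assumption and is not shown in this section. A quantum circuit can only implement such a map if it is reversible, which holds on the residues $0, \dots, N-1$ because $k$ is coprime to $N$. The classical sketch below uses a hypothetical `ctrl_mod_mul` helper and Python's `pow(k, -1, mod)` for the modular inverse (Python 3.8+).

```python
def ctrl_mod_mul(ctrl: int, x: int, k: int, mod: int) -> tuple[int, int]:
    """Hypothetical classical model: x -> (k * x) % mod when ctrl == 1, else unchanged."""
    if ctrl:
        x = (k * x) % mod
    return ctrl, x


k, mod = 8, 3 * 5
k_inv = pow(k, -1, mod)   # 2, since 8 * 2 = 16, which is 1 modulo 15

# Multiplying by k permutes the residues 0..mod-1 because gcd(k, mod) == 1 ...
images = [ctrl_mod_mul(1, x, k, mod)[1] for x in range(mod)]
assert sorted(images) == list(range(mod))

# ... and multiplying by k_inv undoes it, so the operation is reversible.
for x in range(mod):
    _, y = ctrl_mod_mul(1, x, k, mod)
    assert ctrl_mod_mul(1, y, k_inv, mod) == (1, x)

# With the control off, x passes through untouched.
assert ctrl_mod_mul(0, 7, k, mod) == (0, 7)
print("k =", k, "k_inv =", k_inv)
```

This model only covers inputs $x < N$. How the bloq treats register values in $[N, 2^n)$ is not addressed here.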