 * Copyright 2024 Xue_Lexiang
 * Licensed under MIT (https://github.com/xuelx1/LearnQC/LISENCE)

## Shor's algorithm: (3n+2)-qubit version
### Given a composite number n, return a non-trivial factor of n.

### 1. Imports and parameters

In [1]:
from math import sqrt, isqrt, ceil, gcd, log2
from random import randint
from sympy import isprime
from fractions import Fraction

In [2]:
max_times = 10  # maximum number of random bases find_factor tries (one order-finding run each)

runtime = False  # False: run on the local Aer simulator; True: run on an IBM Quantum device

if runtime:
    import os

    from qiskit_ibm_runtime import QiskitRuntimeService

    # Never hard-code API tokens in source: read the token from the environment.
    # The token that used to be written here should be treated as leaked and revoked.
    # If QISKIT_IBM_TOKEN is unset, token=None is passed and the service is expected to
    # fall back to a saved account -- NOTE(review): confirm with the installed
    # qiskit-ibm-runtime version.
    server = QiskitRuntimeService(
        channel='ibm_quantum',
        token=os.environ.get('QISKIT_IBM_TOKEN'),
    )
    backend = server.backend(name='ibm_kyiv')
else:
    from qiskit_aer import AerSimulator

    backend = AerSimulator()

### 2. Order-Finding Algorithm (Quantum Part)

#### The quantum part of Shor's algorithm is implemented in shor_quantum.py, based on https://github.com/kazawai/shor_qiskit.

In [3]:
from shor_quantum import shor_quantum
# Reached only if the import above succeeded; report that to the notebook user.
print("Import completed.", end="\n")

Import completed.


### 3. Shor's Algorithm (Classical Part)

In [4]:
def _integer_root(n, k):
    """Return floor(n ** (1 / k)) for a non-negative int n, computed exactly.

    Binary search over integers avoids floating-point rounding: for example
    ``125 ** (1 / 3)`` evaluates to 4.999999999999999, not 5.0.
    """
    lo, hi = 0, 1 << (n.bit_length() // k + 1)
    # Invariant: lo ** k <= n < hi ** k.
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if mid ** k <= n:
            lo = mid
        else:
            hi = mid
    return lo


def find_factor(n, backend, max_tries=None):
    """Return a non-trivial factor of ``n`` using Shor's algorithm.

    Cheap classical checks run first: even ``n``, perfect powers, and random bases
    that already share a factor with ``n``. Only then is the quantum order-finding
    routine ``shor_quantum`` run on ``backend``.

    Args:
        n: Integer to factor, expected to be composite. For a prime ``n`` no
            non-trivial factor exists, so every attempt fails and RuntimeError is raised.
        backend: Passed through unchanged to ``shor_quantum``.
        max_tries: Number of random bases to try. Defaults to the notebook-level
            ``max_times`` setting.

    Returns:
        A divisor ``f`` of ``n`` with ``1 < f < n``. The exception is even ``n``,
        for which 2 is returned (including ``n == 2`` itself).

    Raises:
        ValueError: If ``n < 2``.
        RuntimeError: If no factor was found within ``max_tries`` attempts.
    """
    if n < 2:
        raise ValueError(f"n must be an integer >= 2, got {n}")
    if n % 2 == 0:
        return 2
    # The gcd step of Shor's algorithm cannot split an odd prime power, so detect
    # perfect powers classically. 2 ** k <= n bounds k by n.bit_length() - 1.
    for k in range(2, n.bit_length()):
        root = _integer_root(n, k)
        if root ** k == n:
            return root

    if max_tries is None:
        max_tries = max_times
    for attempt in range(1, max_tries + 1):
        a = randint(2, n - 1)
        shared = gcd(a, n)
        if shared != 1:
            # Lucky base: since a < n, 1 < gcd(a, n) < n is already a factor.
            return shared
        # Presumably the order r of a modulo n, with 0 signalling a failed run;
        # NOTE(review): confirm against shor_quantum.py.
        r = shor_quantum(a, n, backend)
        if r and r % 2 == 0:
            half_power = pow(a, r // 2, n)
            # Modular pow lies in [0, n), so "a^(r/2) == -1 (mod n)" means n - 1.
            if half_power != n - 1:
                for candidate in (gcd(half_power - 1, n), gcd(half_power + 1, n)):
                    # Any gcd strictly between 1 and n divides n non-trivially.
                    if 1 < candidate < n:
                        return candidate
        print(f"Trying {attempt} of {max_tries}, a == {a}, r == {r}.")
    raise RuntimeError(f"No factor of {n} found after {max_tries} attempts.")



### 4. Factorization

In [5]:
# Integer whose complete prime factorization is computed in the next cell.
N: int = 45

In [6]:
factors = []


def find_inprime(factors):
    """Return the index of the first entry in *factors* that still needs splitting, or -1.

    An entry needs splitting when it is composite: greater than 1 and not prime
    according to ``sympy.isprime``. Values <= 1 have no non-trivial factor, so
    they are skipped and never handed to ``find_factor``.
    """
    for fac_idx, fac in enumerate(factors):
        if fac > 1 and not isprime(fac):
            return fac_idx
    return -1


# Repeatedly split the first composite entry until every entry is prime.
factors.append(N)
fac_idx = find_inprime(factors)
while fac_idx != -1:
    print(factors)
    fac = factors.pop(fac_idx)
    new_fac = find_factor(fac, backend)
    # A None, 1, fac itself, or a non-divisor would otherwise crash later or loop forever.
    if new_fac is None or not (1 < new_fac < fac) or fac % new_fac != 0:
        raise RuntimeError(
            f"find_factor({fac}) returned {new_fac!r}, which is not a non-trivial factor."
        )
    factors.append(new_fac)
    # Integer division: fac / new_fac would round through a float for large fac.
    factors.append(fac // new_fac)
    fac_idx = find_inprime(factors)
print(factors)


[45]
[3, 15]
[3, 5, 3]
