<a href="https://colab.research.google.com/github/sharathchandran2001/GeneralUtils/blob/main/Shor_RSA21_Simulation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🔐 Simulated Shor's Algorithm: Factoring a Small RSA-like Key (N = 21)

This notebook uses PennyLane to simulate a simplified version of Shor's algorithm and factor a small RSA-style number, **N = 21 (3 × 7)**. One cell also attempts the larger N = 1991 (11 × 181).

We simulate the quantum period-finding step, then use classical post-processing to recover the prime factors. The post-processing approximates the measured phase with a continued fraction and then computes GCDs.

In [None]:
!pip install pennylane numpy matplotlib

In [None]:
# Re-import libraries and redefine the complete program after code execution reset
import pennylane as qml
import numpy as np
from fractions import Fraction
from math import gcd
import time
import random

# -------------------------------------
# Utility: Post-processing to extract factors
# -------------------------------------
def get_factors(a, N, r):
    """Try to derive non-trivial factors of ``N`` from a candidate period ``r``.

    Classical post-processing step of Shor's algorithm. For an even ``r``,
    ``x = a**(r // 2) mod N`` is computed and ``gcd(x + 1, N)`` /
    ``gcd(x - 1, N)`` are checked for non-trivial divisors of ``N``.

    Args:
        a: Base whose period was estimated.
        N: Integer being factored.
        r: Candidate period (e.g. the denominator from the phase estimate).

    Returns:
        ``(gcd(x + 1, N), gcd(x - 1, N))`` when both values are non-trivial
        (neither 1 nor ``N``); otherwise ``None``. Odd periods are rejected
        with a printed warning.
    """
    if r % 2 != 0:
        print("⚠️ Period r is odd. Skipping.")
        return None
    # Work modulo N: gcd(m, N) depends only on m mod N, so the results are
    # identical to using the full power, but we avoid building (and printing)
    # integers with thousands of digits. Printing those can raise ValueError
    # on Python 3.11+ once they exceed the int-to-str digit limit.
    half_power = pow(a, r // 2, N)
    plus = half_power + 1
    minus = half_power - 1
    f1 = gcd(plus, N)
    f2 = gcd(minus, N)
    print(f"Trying GCD({plus},{N}) = {f1}, GCD({minus},{N}) = {f2}")
    if f1 == 1 or f2 == 1 or f1 == N or f2 == N:
        return None
    return f1, f2

# -------------------------------------
# Simulate Shor's Algorithm
# -------------------------------------
def run_shor_sim(N, a=None, n_count=20):
    """Run one shot of the simulated period-finding circuit and post-process it.

    NOTE(review): this is a simplified, Shor-inspired circuit rather than a
    full modular-exponentiation oracle. Each counting qubit ``i`` controls a
    single ``RZ`` rotation of ``2*pi*(a**(2**i) mod N)/N`` on wire
    ``n_count``. Wires ``n_count + 1`` .. ``n_count + 5`` are allocated on the
    device but never touched by the circuit.

    Args:
        N: Integer to factor. When ``a`` is ``None`` this must be >= 3,
            otherwise ``random.randint(2, N - 1)`` has an empty range.
        a: Base for the period search. When ``None``, a random base in
            ``[2, N - 1]`` coprime to ``N`` is drawn.
        n_count: Number of counting qubits that are measured.

    Returns:
        dict with keys:
            ``a``: base used.
            ``y``: measured integer.
            ``phase``: ``y / 2**n_count`` as a float.
            ``r``: denominator of the closest fraction to ``y / 2**n_count``
                with denominator <= ``N``.
            ``factors``: return value of ``get_factors(a, N, r)``.
            ``elapsed_time``: seconds spent running the circuit and
                post-processing.
            ``bitstring``: measured bits of the counting register.
    """
    total_wires = n_count + 6
    dev = qml.device('default.qubit', wires=total_wires, shots=1)

    if a is None:
        while True:
            a = random.randint(2, N - 1)
            if gcd(a, N) == 1:
                break

    def controlled_modular_exponentiation():
        # Phase-kickback stand-in for modular exponentiation: counting qubit i
        # controls an RZ whose angle encodes a**(2**i) mod N.
        for i in range(n_count):
            exponent = 2 ** i
            angle = (2 * np.pi * pow(a, exponent, N)) / N
            qml.ctrl(qml.RZ, control=i)(angle, wires=n_count)

    @qml.qnode(dev)
    def shor_circuit():
        # Uniform superposition over the counting register.
        for i in range(n_count):
            qml.Hadamard(wires=i)
        controlled_modular_exponentiation()
        # Bit reversal followed by an inverse-QFT-style ladder.
        # NOTE(review): the ladder uses controlled-RZ, whereas a textbook QFT
        # uses controlled phase shifts; these differ by a relative phase.
        for i in range(n_count // 2):
            qml.SWAP(wires=[i, n_count - i - 1])
        for j in range(n_count):
            qml.Hadamard(wires=j)
            for k in range(j + 1, n_count):
                angle = -np.pi / (2 ** (k - j))
                qml.ctrl(qml.RZ, control=k)(angle, wires=j)
        return qml.sample(wires=range(n_count))

    start = time.time()
    measurement = shor_circuit()
    # Flatten defensively: a single-shot sample may come back either as
    # (n_count,) or with an extra shot axis, depending on the PennyLane version.
    bits = np.asarray(measurement).reshape(-1).astype(int)
    bitstring = ''.join(str(b) for b in bits)
    y = int(bitstring, 2)
    phase = y / (2 ** n_count)
    # Build the fraction from integers so it stays exact for any n_count
    # (a float phase only carries 53 bits of precision).
    frac = Fraction(y, 2 ** n_count).limit_denominator(N)
    r = frac.denominator
    factors = get_factors(a, N, r)
    elapsed = time.time() - start

    return {
        "a": a,
        "y": y,
        "phase": phase,
        "r": r,
        "factors": factors,
        "elapsed_time": elapsed,
        "bitstring": bitstring
    }

# -------------------------------------
# Run simulation with retry logic
# -------------------------------------
N = 1991  # Simulated RSA number: 11 * 181
# Counting qubits per run. run_shor_sim allocates 6 additional wires on top of
# these, so the summary derives the qubit totals from this single value.
n_count = 20
max_attempts = 10000
attempt_times = []
success_result = None
attempt = 0
start_total = time.time()

for attempt in range(1, max_attempts + 1):
    print(f"\n🔁 Attempt {attempt} to factor N = {N}")
    start_iter = time.time()
    results = run_shor_sim(N, n_count=n_count)
    iter_time = time.time() - start_iter
    attempt_times.append(iter_time)

    print("Randomly selected a =", results["a"])
    print("Measurement bitstring =", results["bitstring"])
    print("Measured integer y =", results["y"])
    print("Estimated phase = {:.6f}".format(results["phase"]))
    print("Estimated period r =", results["r"])
    print("Elapsed time = {:.4f} seconds".format(results["elapsed_time"]))

    if results["factors"]:
        success_result = results
        break

end_total = time.time()
total_duration = end_total - start_total
# Guard against an empty list (max_attempts <= 0); np.mean([]) warns and returns nan.
average_iter_time = float(np.mean(attempt_times)) if attempt_times else 0.0

# -------------------------------------
# Final Summary
# -------------------------------------
print("\n===============================")
if success_result:
    print(f"✅ SUCCESS after {attempt} attempts")
    print(f"🔐 N = {N} factored into: {success_result['factors']}")
    print(f"🕒 Total time to find solution: {total_duration:.4f} seconds")
    print(f"⏱️ Average time per attempt: {average_iter_time:.4f} seconds")
    print(f"🧠 Total qubits used: {n_count + 6} (Counting: {n_count}, Extra: 6)")
else:
    print("❌ Failed to factor the number after max attempts.")
    print(f"⏱️ Total time: {total_duration:.4f} seconds")
    print(f"⏱️ Average attempt time: {average_iter_time:.4f} seconds")
print("===============================")



🔁 Attempt 1 to factor N = 1991
⚠️ Period r is odd. Skipping.
Randomly selected a = 659
Measurement bitstring = 00001111111101111100
Measured integer y = 65404
Estimated phase = 0.062374
Estimated period r = 497
Elapsed time = 7.2426 seconds

🔁 Attempt 2 to factor N = 1991
⚠️ Period r is odd. Skipping.
Randomly selected a = 786
Measurement bitstring = 00000000000011000000
Measured integer y = 192
Estimated phase = 0.000183
Estimated period r = 1
Elapsed time = 6.6336 seconds

🔁 Attempt 3 to factor N = 1991
⚠️ Period r is odd. Skipping.
Randomly selected a = 813
Measurement bitstring = 10111010100110100111
Measured integer y = 764327
Estimated phase = 0.728919
Estimated period r = 1767
Elapsed time = 9.0064 seconds

🔁 Attempt 4 to factor N = 1991
⚠️ Period r is odd. Skipping.
Randomly selected a = 805
Measurement bitstring = 00111110010000001000
Measured integer y = 254984
Estimated phase = 0.243172
Estimated period r = 1135
Elapsed time = 6.6818 seconds

🔁 Attempt 5 to factor N = 1991


In [None]:
# Install dependencies (uncomment if running for the first time)
# !pip install pennylane numpy matplotlib
import pennylane as qml
import numpy as np
from fractions import Fraction
from math import gcd
import time
import random


In [None]:
# Utility: Classical post-processing to extract factors from period
def get_factors(a, N, r):
    """Try to derive non-trivial factors of ``N`` from a candidate period ``r``.

    Classical post-processing step of Shor's algorithm. For an even ``r``,
    ``x = a**(r // 2) mod N`` is computed and ``gcd(x + 1, N)`` /
    ``gcd(x - 1, N)`` are checked for non-trivial divisors of ``N``.

    Args:
        a: Base whose period was estimated.
        N: Integer being factored.
        r: Candidate period (e.g. the denominator from the phase estimate).

    Returns:
        ``(gcd(x + 1, N), gcd(x - 1, N))`` when both values are non-trivial
        (neither 1 nor ``N``); otherwise ``None`` (including for odd ``r``).
    """
    if r % 2 != 0:
        return None
    # Work modulo N: gcd(m, N) depends only on m mod N, so the results match
    # the full power exactly without building integers with thousands of digits.
    half_power = pow(a, r // 2, N)
    f1 = gcd(half_power + 1, N)
    f2 = gcd(half_power - 1, N)
    if f1 == 1 or f2 == 1 or f1 == N or f2 == N:
        return None
    return f1, f2


In [None]:
# Simulated Shor's Algorithm for small RSA-like key
def run_shor_sim(N, a=None, n_count=8):
    """Run one shot of the simulated period-finding circuit and post-process it.

    NOTE(review): this is a simplified, Shor-inspired circuit rather than a
    full modular-exponentiation oracle. Each counting qubit ``i`` controls a
    single ``RZ`` rotation of ``2*pi*(a**(2**i) mod N)/N`` on wire
    ``n_count``. Wires ``n_count + 1`` .. ``n_count + 5`` are allocated on the
    device but never touched by the circuit.

    Args:
        N: Integer to factor. When ``a`` is ``None`` this must be >= 3,
            otherwise ``random.randint(2, N - 1)`` has an empty range.
        a: Base for the period search. When ``None``, a random base in
            ``[2, N - 1]`` coprime to ``N`` is drawn.
        n_count: Number of counting qubits that are measured.

    Returns:
        dict with keys:
            ``a``: base used.
            ``y``: measured integer.
            ``phase``: ``y / 2**n_count`` as a float.
            ``r``: denominator of the closest fraction to ``y / 2**n_count``
                with denominator <= ``N``.
            ``factors``: return value of ``get_factors(a, N, r)``.
            ``elapsed_time``: seconds spent running the circuit and
                post-processing.
            ``bitstring``: measured bits of the counting register.
    """
    total_wires = n_count + 6
    dev = qml.device('default.qubit', wires=total_wires, shots=1)

    if a is None:
        while True:
            a = random.randint(2, N - 1)
            if gcd(a, N) == 1:
                break

    def controlled_modular_exponentiation():
        # Phase-kickback stand-in for modular exponentiation: counting qubit i
        # controls an RZ whose angle encodes a**(2**i) mod N.
        for i in range(n_count):
            exponent = 2 ** i
            angle = (2 * np.pi * pow(a, exponent, N)) / N
            qml.ctrl(qml.RZ, control=i)(angle, wires=n_count)

    @qml.qnode(dev)
    def shor_circuit():
        # Uniform superposition over the counting register.
        for i in range(n_count):
            qml.Hadamard(wires=i)
        controlled_modular_exponentiation()
        # Bit reversal followed by an inverse-QFT-style ladder.
        # NOTE(review): the ladder uses controlled-RZ, whereas a textbook QFT
        # uses controlled phase shifts; these differ by a relative phase.
        for i in range(n_count // 2):
            qml.SWAP(wires=[i, n_count - i - 1])
        for j in range(n_count):
            qml.Hadamard(wires=j)
            for k in range(j + 1, n_count):
                angle = -np.pi / (2 ** (k - j))
                qml.ctrl(qml.RZ, control=k)(angle, wires=j)
        return qml.sample(wires=range(n_count))

    start = time.time()
    measurement = shor_circuit()
    # Flatten defensively: a single-shot sample may come back either as
    # (n_count,) or with an extra shot axis, depending on the PennyLane version.
    bits = np.asarray(measurement).reshape(-1).astype(int)
    bitstring = ''.join(str(b) for b in bits)
    y = int(bitstring, 2)
    phase = y / (2 ** n_count)
    # Build the fraction from integers so it stays exact for any n_count
    # (a float phase only carries 53 bits of precision).
    frac = Fraction(y, 2 ** n_count).limit_denominator(N)
    r = frac.denominator
    factors = get_factors(a, N, r)
    elapsed = time.time() - start

    return {
        "a": a,
        "y": y,
        "phase": phase,
        "r": r,
        "factors": factors,
        "elapsed_time": elapsed,
        "bitstring": bitstring
    }


In [None]:
# Factor the smallest RSA-like semiprime used here: N = 21 = 3 x 7.
N = 21
results = run_shor_sim(N)

# Print each reported field as "<label> <value>", in a fixed order.
report_fields = (
    ("Randomly selected a =", "a"),
    ("Measurement bitstring =", "bitstring"),
    ("Measured integer y =", "y"),
    ("Estimated phase =", "phase"),
    ("Estimated period r =", "r"),
)
for label, key in report_fields:
    print(label, results[key])

elapsed_line = "Elapsed time = {:.4f} seconds".format(results["elapsed_time"])
print(elapsed_line)

# get_factors yields None (falsy) when no non-trivial factor pair was found.
if not results["factors"]:
    print("❌ Failed to find non-trivial factors.")
else:
    print("✅ Success! Factors of", N, "are:", results["factors"])
