# Understanding Montgomery Reduction

## 1. Introduction: The Problem with Modular Division

Modular arithmetic often involves computing remainders: for integers `a` and `n`, we find `q` and `r` such that:

$$
a = qn + r,\quad \text{where} \quad 0 \leq r < |n|
$$

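Here, `r = a mod n`. Computing `r` requires a division, which is typically the slowest of the basic arithmetic operations. This is manageable for small numbers but becomes a real cost at scale, especially in cryptography, where operands are hundreds or thousands of bits long.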

For example, to compute:

$$
\begin{aligned}
(12 \times 15) \bmod 7 &= ((12 \bmod 7) \times (15 \bmod 7)) \bmod 7 \\
&= (5 \times 1) \bmod 7 \\
&= 5
\end{aligned}
$$

Reducing the operands first keeps the intermediate numbers small, but every step still needs a `mod` operation (a division by 7 here, or by `n` in general), and division is costly for large numbers. Since cryptographic systems rely heavily on modular multiplication, this repeated division becomes a bottleneck.
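The identity is easy to confirm in Python. Note that both expressions still rely on `%`, that is, on a division by `n`:

```python
a, b, n = 12, 15, 7

direct = (a * b) % n                       # reduce the full product 180
reduced_first = ((a % n) * (b % n)) % n    # reduce operands first: (5 * 1) % 7

print(direct, reduced_first)  # 5 5
```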

**Montgomery reduction** addresses this by avoiding direct division, making modular multiplication more efficient for large integers.

## 2. The Core Idea: A New Domain for Faster Math

Montgomery reduction speeds up modular arithmetic by moving calculations into a special **Montgomery domain**, avoiding costly division by `n`.

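This is done using an auxiliary modulus `R` that is larger than `n` and coprime to it. `R` is typically a power of 2 (such as $2^{32}$ or $2^{64}$ for a single-word modulus, or a larger power of 2 for multi-word numbers). A power of 2 is coprime to `n` exactly when `n` is odd, which is why Montgomery arithmetic is used with odd moduli.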

### Why Use a Power of 2?

Because computers handle powers of 2 efficiently:

- **Division by `R`** → simple **right shift**
- **Modulo `R`** → fast **bitwise AND**

This makes reductions much faster than regular division.
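A minimal Python check of both identities, assuming a small hypothetical `R = 2**4`. In hardware these typically map to shift and AND instructions, or, when `R` is a multiple of the word size, to simply keeping or dropping whole words:

```python
k = 4
R = 1 << k          # R = 16
mask = R - 1        # 0b1111: the low k bits

T = 88
print(T >> k, T & mask)  # 5 8

# For non-negative T, division by R is a right shift and reduction mod R is a mask.
for x in range(1000):
    assert x // R == x >> k
    assert x % R == x & mask
```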

## 3. The Montgomery Algorithm: Setup and Multiplication

With `R` as a power of 2, we set up the Montgomery system through a one-time preparation:

### Setup Phase

1. **Choose `R`**:  
   A power of 2 greater than `n` (and coprime to it, which requires `n` to be odd), enabling fast bitwise operations.

2. **Compute $R^{-1}$**:  
   The modular inverse of `R` modulo `n`, such that:

   $$
   R \cdot R^{-1} \equiv 1 \pmod{n}
   $$

3. **Compute `n'`**:  
   The modular inverse of `-n` modulo `R`, satisfying:

   $$
   -n \cdot n' \equiv 1 \pmod{R}
   $$

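Both values depend only on `n` and `R`, so they are precomputed once. In fact only `n'` is used by the reduction algorithm itself; $R^{-1}$ is needed to state what the algorithm computes, but it never has to be computed explicitly.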
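As a sketch of the setup for the toy modulus `n = 13`, the snippet below picks `R` as the smallest power of 2 above `n` (the same choice the notebook class makes) and uses Python's built-in three-argument `pow` with exponent `-1` (Python 3.8+) to get both inverses, instead of a hand-written extended Euclidean helper:

```python
n = 13
k = n.bit_length()        # 4, so R = 16 > n
R = 1 << k

R_inv = pow(R, -1, n)     # R * R_inv ≡ 1 (mod n)
n_prime = pow(-n, -1, R)  # -n * n_prime ≡ 1 (mod R)

assert (R * R_inv) % n == 1
assert (-n * n_prime) % R == 1
print(R, R_inv, n_prime)  # 16 9 11
```

If `n` were even, `pow(-n, -1, R)` would raise `ValueError`, since no inverse modulo a power of 2 exists.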


### Conversion to Montgomery Form

To work in the Montgomery domain, convert a number `a` to its Montgomery form `a'`:

$$
a' = a \cdot R \bmod n
$$

This conversion is an ordinary modular multiplication (it still divides by `n`), but it is done only once per input value: it is the "entry fee" for faster computation afterwards.
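A small sketch of the conversion for `n = 13`, `R = 16`. The helper names `to_mont` and `from_mont_naive` are hypothetical; the naive inverse multiplies by $R^{-1}$ and reduces mod `n`, which is only a reference for checking, not how Montgomery code converts back. Because `R` is coprime to `n`, conversion permutes the residues, so no information is lost:

```python
n, R = 13, 16
R_inv = pow(R, -1, n)  # 9

def to_mont(a):
    """Montgomery form a' = a * R mod n (one ordinary reduction mod n)."""
    return (a * R) % n

def from_mont_naive(a_mont):
    """Reference inverse a = a' * R^{-1} mod n (still divides by n)."""
    return (a_mont * R_inv) % n

assert to_mont(7) == 8 and to_mont(8) == 11
assert sorted(to_mont(a) for a in range(n)) == list(range(n))  # a bijection on residues
assert all(from_mont_naive(to_mont(a)) == a for a in range(n))
```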


### Multiplication in the Montgomery Domain

Given `a'` and `b'` in Montgomery form, their Montgomery product `c'` is defined as:

$$
c' = a' \cdot b' \cdot R^{-1} \bmod n
$$

Computing this directly would still require a reduction modulo `n`, but the **Montgomery Reduction (`REDC`)** algorithm produces the same value without ever dividing by `n`.
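The reason the Montgomery domain is closed under this product is a one-line calculation. Substituting $a' \equiv aR$ and $b' \equiv bR$:

$$
\begin{aligned}
c' &\equiv a' \cdot b' \cdot R^{-1} \\
   &\equiv (aR)(bR)R^{-1} \\
   &\equiv (ab) \cdot R \pmod{n}
\end{aligned}
$$

So $c'$ is exactly the Montgomery form of $c = ab \bmod n$. With the values from the notebook example later in this section ($n = 13$, $R = 16$, $R^{-1} = 9$, $a = 7$, $b = 8$, hence $a' = 8$, $b' = 11$): $8 \cdot 11 \cdot 9 \bmod 13 = 12$ and $(7 \cdot 8) \cdot 16 \bmod 13 = 12$.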

## 4. The REDC Algorithm and Final Conversion

To compute the Montgomery product $c' = a' \cdot b' \cdot R^{-1} \bmod n$, we use the **`REDC`** function, which efficiently calculates $T \cdot R^{-1} \bmod n$ for $T = a' \cdot b'$. `REDC` assumes $0 \le T < nR$, which always holds here because $a', b' < n < R$.

### The `REDC` Algorithm

1. **Compute `m`:**

   $$
   m = \big( (T \bmod R) \cdot n' \big) \bmod R
   $$

2. **Compute `t`:**

   $$
   t = \frac{T + m \cdot n}{R}
   $$

3. **Final correction:**

   $$
   \text{If } t \geq n, \quad \text{then } t = t - n
   $$

4. **Return `t` as the result.**

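The speed comes from the fact that every reduction inside `REDC` is by `R`, not by `n`: both `mod R` steps are bitwise ANDs, and the division by `R` is a right shift (the sum $T + m \cdot n$ is always an exact multiple of `R`). Apart from two multiplications (by `n'` and by `n`), `REDC` uses only masks, a shift, one addition, and at most one subtraction.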
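Why the algorithm is correct: from the setup, $n \cdot n' \equiv -1 \pmod{R}$, and by construction $m \equiv T \cdot n' \pmod{R}$. Assuming $0 \le T < nR$:

$$
\begin{aligned}
T + m \cdot n &\equiv T + (T \cdot n') \cdot n \equiv T(1 + n \cdot n') \equiv T(1 - 1) \equiv 0 \pmod{R} \\
t \cdot R &= T + m \cdot n \equiv T \pmod{n} \quad \Longrightarrow \quad t \equiv T \cdot R^{-1} \pmod{n} \\
0 \le t &= \frac{T + m \cdot n}{R} < \frac{nR + Rn}{R} = 2n
\end{aligned}
$$

The first line says the division by $R$ is exact, the second that $t$ is the desired residue, and the third that one conditional subtraction brings $t$ into $[0, n)$. For example, with $n = 13$, $R = 16$, $n' = 11$ and $T = 88$: $m = (8 \cdot 11) \bmod 16 = 8$, $T + m \cdot n = 192 = 12 \cdot 16$, so $t = 12 < 13$.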


### Conversion Back to Standard Form

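After all operations, the result `c'` is still in Montgomery form. To convert back, apply one last `REDC` to `c'` itself, which removes the extra factor of `R`:

$$
\text{REDC}(c') \equiv c' \cdot R^{-1} \equiv (c \cdot R) \cdot R^{-1} \equiv c \pmod{n}
$$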
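A standalone sketch of `REDC` for the toy parameters `n = 13`, `R = 16` (the helper name `redc` is hypothetical). It traces the worked example and then checks the whole pipeline (convert in, multiply, convert back) for every pair of residues:

```python
n, k = 13, 4
R, mask = 1 << k, (1 << k) - 1
n_prime = pow(-n, -1, R)  # 11, so that n * n_prime ≡ -1 (mod R)

def redc(T):
    """Return T * R^{-1} mod n for 0 <= T < n * R, without dividing by n."""
    m = ((T & mask) * n_prime) & mask   # m = (T mod R) * n' mod R
    t = (T + m * n) >> k                # exact division by R
    return t - n if t >= n else t       # t < 2n, so at most one correction

# Worked example: a = 7, b = 8  ->  a' = 8, b' = 11
c_mont = redc(8 * 11)
print(c_mont, redc(c_mont))  # 12 4

# Round trip for every pair of residues mod n
for a in range(n):
    for b in range(n):
        a_m, b_m = (a * R) % n, (b * R) % n
        assert redc(redc(a_m * b_m)) == (a * b) % n
```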

```python
class Montgomery:
    """Montgomery arithmetic modulo an odd n, with R = 2**k the smallest power of 2 above n."""

    def __init__(self, n):
        # R = 2**k is coprime to n exactly when n is odd.
        if n % 2 == 0:
            raise ValueError("Modulus n must be odd.")

        self.n = n
        self.logR = n.bit_length()   # k, so that R = 2**k > n
        self.R = 1 << self.logR
        self.R_mask = self.R - 1     # x & R_mask == x mod R

        # n' = -n^{-1} mod R, so that n * n' ≡ -1 (mod R)
        n_inv_R = self._modinv(self.n, self.R)
        self.n_prime = self.R - n_inv_R

    def _egcd(self, a, b):
        # Extended Euclid: returns (g, x, y) with a*x + b*y == g == gcd(a, b)
        if a == 0:
            return (b, 0, 1)
        g, y, x = self._egcd(b % a, a)
        return (g, x - (b // a) * y, y)

    def _modinv(self, a, m):
        g, x, y = self._egcd(a, m)
        if g != 1:
            raise ValueError('Modular inverse does not exist')
        return x % m

    def _reduce(self, T):
        # REDC: returns T * R^{-1} mod n, assuming 0 <= T < n * R
        m = ((T & self.R_mask) * self.n_prime) & self.R_mask  # (T mod R) * n' mod R
        t = (T + m * self.n) >> self.logR                     # exact division by R
        # t < 2n, so at most one subtraction is needed
        if t >= self.n:
            return t - self.n
        else:
            return t

    def convert_in(self, x):
        # x' = x * R mod n (an ordinary reduction mod n)
        return (x * self.R) % self.n

    def convert_out(self, x_mont):
        # REDC(x') = x' * R^{-1} mod n = x
        return self._reduce(x_mont)

    def multiply(self, a_mont, b_mont):
        # Montgomery product: REDC(a' * b') = a' * b' * R^{-1} mod n
        return self._reduce(a_mont * b_mont)

# --- Usage Example ---
# 1. One-time setup
monty_system = Montgomery(n=13)
print(f"System Initialized for n={monty_system.n}")
print(f"Calculated R={monty_system.R}, n'={monty_system.n_prime}")
print("-" * 30)

# 2. Convert numbers to Montgomery form
a = 7
b = 8
a_mont = monty_system.convert_in(a)
b_mont = monty_system.convert_in(b)
print(f"{a} in Montgomery form is: {a_mont}")
print(f"{b} in Montgomery form is: {b_mont}")
print("-" * 30)

# 3. Perform multiplication in the Montgomery domain
intermediate_T = a_mont * b_mont                      # plain product fed to REDC
product_mont = monty_system.multiply(a_mont, b_mont)  # REDC(T)
print(f"Intermediate product T = {a_mont} * {b_mont} = {intermediate_T}")
print(f"REDC(T) -> Product in Montgomery form: {product_mont}")
print("-" * 30)

# 4. Convert the result back to a standard number
final_result = monty_system.convert_out(product_mont)
print(f"Final Result (after converting back): {final_result}")
print(f"Standard Check: ({a} * {b}) % {monty_system.n} = {(a * b) % monty_system.n}")
```

```
System Initialized for n=13
Calculated R=16, n'=11
------------------------------
7 in Montgomery form is: 8
8 in Montgomery form is: 11
------------------------------
Intermediate product T = 8 * 11 = 88
REDC(T) -> Product in Montgomery form: 12
------------------------------
Final Result (after converting back): 4
Standard Check: (7 * 8) % 13 = 4
```


---

### A Quick Clarification: Why Does Montgomery Seem Slower at First?

> **Question:**  
> Montgomery starts with divisions like 112 mod 13 or 128 mod 13, which look slower than a simple 56 mod 13. So why is it still preferred—and actually faster—when working with large numbers?

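For a single, small calculation, the setup and conversion costs make Montgomery reduction **slower** than the standard method: converting 7 and 8 into Montgomery form already takes two ordinary reductions (112 mod 13 and 128 mod 13), one more than computing 56 mod 13 directly.

The performance boost isn’t for one-off calculations; it’s for **chains of multiplications** performed with the same modulus, which is extremely common in cryptography. The conversions are paid once per input, while every multiplication in the chain uses the division-free `REDC`. The classic use case is **modular exponentiation** (`a^e mod n`), which is the core of RSA and the example we will explore next: for a `k`-bit exponent, square-and-multiply performs about `k` squarings plus one multiplication per 1-bit, all with the same modulus.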

```python
# Re-defining the Montgomery class so this cell runs on its own.
# Here multiply() returns only the reduced product REDC(a' * b').
class Montgomery:
    def __init__(self, n):
        if n % 2 == 0: raise ValueError("Modulus n must be odd.")
        self.n = n
        self.logR = n.bit_length()          # R = 2**logR > n
        self.R = 1 << self.logR
        self.R_mask = self.R - 1            # x & R_mask == x mod R
        n_inv_R = self._modinv(self.n, self.R)
        self.n_prime = self.R - n_inv_R     # n * n' ≡ -1 (mod R)

    def _egcd(self, a, b):
        if a == 0: return (b, 0, 1)
        g, y, x = self._egcd(b % a, a)
        return (g, x - (b // a) * y, y)

    def _modinv(self, a, m):
        g, x, y = self._egcd(a, m)
        if g != 1: raise ValueError('Modular inverse does not exist')
        return x % m

    def _reduce(self, T):
        # REDC: T * R^{-1} mod n using a mask, a shift and one conditional subtraction
        m = ((T & self.R_mask) * self.n_prime) & self.R_mask
        t = (T + m * self.n) >> self.logR
        return t - self.n if t >= self.n else t

    def convert_in(self, x):
        return (x * self.R) % self.n        # uses % n: an ordinary (expensive) reduction

    def convert_out(self, x_mont):
        return self._reduce(x_mont)

    def multiply(self, a_mont, b_mont):
        return self._reduce(a_mont * b_mont)

# --- Method 1: Standard Exponentiation Trace ---
def trace_standard_pow_clean(base, exp, mod):
    """Left-to-right square-and-multiply, reducing with % mod after every step."""
    print("--- Starting Standard Modular Exponentiation ---")
    expensive_ops = 0
    res = 1
    binary_exp = bin(exp)[2:]
    print(f"Executing for exponent {exp} (binary: {binary_exp})\n")
    
    for i, bit in enumerate(binary_exp):
        # Squaring step
        res_old = res
        res = (res * res) % mod
        expensive_ops += 1
        print(f"Step {i+1} (Square): ({res_old}*{res_old}) % {mod} -> {res} (EXPENSIVE)")
        
        # Multiplication step if bit is 1
        if bit == '1':
            res_old = res
            res = (res * base) % mod
            expensive_ops += 1
            print(f"Step {i+1} (Mult):   ({res_old}*{base}) % {mod} -> {res} (EXPENSIVE)")
            
    print(f"\nFinal Result: {res}")
    print(f"Total Expensive (mod n) Operations: {expensive_ops}\n")

# --- Method 2: Montgomery Exponentiation Trace ---
def trace_montgomery_pow_clean(base, exp, n):
    """Same square-and-multiply loop, but every step uses REDC instead of % n."""
    print("--- Starting Montgomery Modular Exponentiation ---")
    expensive_ops = 0
    
    # 1. Setup & Conversion (the only reductions by n)
    print("Step 1: Setup & Initial Conversion")
    monty = Montgomery(n)
    res_mont = monty.convert_in(1)
    expensive_ops += 1
    base_mont = monty.convert_in(base)
    expensive_ops += 1
    print(f"  - Converting 1 -> {res_mont} (EXPENSIVE OP #{expensive_ops-1})")
    print(f"  - Converting {base} -> {base_mont} (EXPENSIVE OP #{expensive_ops})\n")

    # 2. Main Loop: every step stays in the Montgomery domain
    print("Step 2: Main loop with FAST operations")
    binary_exp = bin(exp)[2:]
    for i, bit in enumerate(binary_exp):
        # Squaring step
        res_old = res_mont
        res_mont = monty.multiply(res_mont, res_mont)
        print(f"Loop {i+1} (Square): REDC({res_old}*{res_old}) -> {res_mont} (FAST)")

        # Multiplication step if bit is 1
        if bit == '1':
            res_old = res_mont
            res_mont = monty.multiply(res_mont, base_mont)
            print(f"Loop {i+1} (Mult):   REDC({res_old}*{base_mont}) -> {res_mont} (FAST)")

    # 3. Final Conversion back to standard form
    final_res = monty.convert_out(res_mont)
    print(f"\nStep 3: Final conversion -> REDC({res_mont}) -> {final_res} (FAST)")
    
    print(f"\nFinal Result: {final_res}")
    print(f"Total Expensive (mod n) Operations: {expensive_ops}\n")

# --- Run the cleaned traces ---
base, exp, mod = 5, 10, 13
trace_standard_pow_clean(base, exp, mod)
print("="*50)
trace_montgomery_pow_clean(base, exp, mod)
```

```
--- Starting Standard Modular Exponentiation ---
Executing for exponent 10 (binary: 1010)

Step 1 (Square): (1*1) % 13 -> 1 (EXPENSIVE)
Step 1 (Mult):   (1*5) % 13 -> 5 (EXPENSIVE)
Step 2 (Square): (5*5) % 13 -> 12 (EXPENSIVE)
Step 3 (Square): (12*12) % 13 -> 1 (EXPENSIVE)
Step 3 (Mult):   (1*5) % 13 -> 5 (EXPENSIVE)
Step 4 (Square): (5*5) % 13 -> 12 (EXPENSIVE)

Final Result: 12
Total Expensive (mod n) Operations: 6

--- Starting Montgomery Modular Exponentiation ---
Step 1: Setup & Initial Conversion
  - Converting 1 -> 3 (EXPENSIVE OP #1)
  - Converting 5 -> 2 (EXPENSIVE OP #2)

Step 2: Main loop with FAST operations
Loop 1 (Square): REDC(3*3) -> 3 (FAST)
Loop 1 (Mult):   REDC(3*2) -> 2 (FAST)
Loop 2 (Square): REDC(2*2) -> 10 (FAST)
Loop 3 (Square): REDC(10*10) -> 3 (FAST)
Loop 3 (Mult):   REDC(3*2) -> 2 (FAST)
Loop 4 (Square): REDC(2*2) -> 10 (FAST)

Step 3: Final conversion -> REDC(10) -> 12 (FAST)

Final Result: 12
Total Expensive (mod n) Operations: 2
```
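The same loop works unchanged at cryptographic sizes. Below is a compact sketch (the function name `mont_pow` is hypothetical, and the modulus $2^{255} - 19$ is used only as a convenient large odd number) that checks the result against Python's built-in `pow`. In pure Python it is unlikely to beat `pow`, since the shifts and masks are themselves interpreted arbitrary-precision operations; the point is that the conversion work stays fixed at two reductions mod `n` plus one final `REDC`, while the loop (here 201 squarings plus one multiplication per 1-bit of the exponent) uses only `REDC`:

```python
def mont_pow(base, exp, n):
    """Left-to-right square-and-multiply using Montgomery multiplication (sketch)."""
    if n % 2 == 0:
        raise ValueError("n must be odd")
    if exp < 0:
        raise ValueError("exp must be non-negative")
    k = n.bit_length()
    R, mask = 1 << k, (1 << k) - 1
    n_prime = pow(-n, -1, R)          # n * n_prime ≡ -1 (mod R), Python 3.8+

    def redc(T):
        m = ((T & mask) * n_prime) & mask
        t = (T + m * n) >> k
        return t - n if t >= n else t

    res = R % n                       # Montgomery form of 1
    base_m = (base * R) % n           # Montgomery form of base
    for bit in bin(exp)[2:]:
        res = redc(res * res)         # square
        if bit == "1":
            res = redc(res * base_m)  # multiply
    return redc(res)                  # convert back to standard form

n = (1 << 255) - 19                   # an odd 255-bit modulus, chosen only as an example
base, exp = 123456789, (1 << 200) + 12345
assert mont_pow(base, exp, n) == pow(base, exp, n)
assert mont_pow(5, 10, 13) == 12      # matches the traces above
```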

