In [18]:
import time
from tabulate import tabulate

In [19]:
class BinaryNumber:
    """An integer paired with the list of its binary digits.

    Attributes:
        decimal_val: the integer passed to the constructor.
        binary_vec: list of '0'/'1' characters, most-significant bit first,
            without leading zeros (0 is stored as ['0']). Intended for
            non-negative integers; a negative n would put a '-' character
            at the front of this list.
    """

    def __init__(self, n):
        self.decimal_val = n
        # format(n, 'b') gives the base-2 digits without the '0b' prefix
        self.binary_vec = list(format(n, 'b'))

    def __repr__(self):
        bits = ''.join(self.binary_vec)
        return 'decimal=%d binary=%s' % (self.decimal_val, bits)

# some useful utility functions to manipulate bit vectors
def binary2int(binary_vec):
    '''Turn a list of bit characters back into a BinaryNumber.

    Despite the name, the result is a BinaryNumber rather than a bare int.
    An empty list is treated as the number 0.
    '''
    value = int(''.join(binary_vec), 2) if len(binary_vec) else 0
    return BinaryNumber(value)

def split_number(vec):
    '''Cut a bit list in half and return (first_half, second_half) as BinaryNumbers.

    The first half holds the high-order bits. For an odd length the extra
    bit ends up in the second (low-order) half.
    '''
    mid = len(vec) // 2
    high_bits, low_bits = vec[:mid], vec[mid:]
    return binary2int(high_bits), binary2int(low_bits)

def bit_shift(number, n):
    '''Shift `number` left by `n` bits, returning a new BinaryNumber.

    Appending n '0' bits to the end of the bit list multiplies the value by
    2**n (e.g. 3 becomes 12 for n = 2); it does not square it. A non-positive
    n appends nothing, so the value is left unchanged.
    '''
    # list concatenation builds a new list, so `number` itself is not modified
    return binary2int(number.binary_vec + ['0'] * n)
    
def pad(x, y):
    """Left-pad two bit lists with '0' so they share one even length.

    Leading zeros do not change the value; they only make both operands
    split at the same position, with halves of equal size. The input lists
    are not modified. Returns the tuple (x, y).
    """
    width = max(len(x), len(y))
    # equalise the lengths first ...
    x = ['0'] * (width - len(x)) + x
    y = ['0'] * (width - len(y)) + y
    # ... then make the shared length even so the split is symmetric
    if width % 2 == 1:
        x, y = ['0'] + x, ['0'] + y
    return x, y

#### Samples that use the provided helper functions and display each product in decimal and binary

In [20]:
def quadratic_multiply(x, y):
    """Multiply two BinaryNumbers using four half-size recursive products.

    Args:
        x, y: BinaryNumber operands.

    Returns:
        The product x * y as a list of '0'/'1' characters, most-significant
        bit first.
    """
    # Base case: a single-digit operand stops the recursion and the
    # product is computed directly.
    if 1 in (len(x.binary_vec), len(y.binary_vec)):
        return list(bin(x.decimal_val * y.decimal_val)[2:])

    # Equal, even-length bit lists let both operands split at the same point.
    x_bits, y_bits = pad(x.binary_vec, y.binary_vec)
    width = len(x_bits)
    x_hi, x_lo = split_number(x_bits)
    y_hi, y_lo = split_number(y_bits)

    # Four recursive sub-products per level -- this is what makes the
    # scheme quadratic.
    hi_hi = binary2int(quadratic_multiply(x_hi, y_hi)).decimal_val
    lo_lo = binary2int(quadratic_multiply(x_lo, y_lo)).decimal_val
    hi_lo = binary2int(quadratic_multiply(x_hi, y_lo)).decimal_val
    lo_hi = binary2int(quadratic_multiply(x_lo, y_hi)).decimal_val

    # x*y = hi_hi * 2^width + (hi_lo + lo_hi) * 2^(width/2) + lo_lo
    total = (hi_hi << width) + ((hi_lo + lo_hi) << (width // 2)) + lo_lo
    return list(bin(total)[2:])

print(binary2int(quadratic_multiply(BinaryNumber(2), BinaryNumber(2))))
print(binary2int(quadratic_multiply(BinaryNumber(3), BinaryNumber(7))))

decimal=4 binary=100
decimal=21 binary=10101


In [None]:
def subquadratic_multiply(x, y):
    """Multiply two BinaryNumbers with Karatsuba's three-product recursion.

    Args:
        x, y: BinaryNumber operands.

    Returns:
        The product x * y as a list of '0'/'1' characters, most-significant
        bit first.
    """
    # Base case: an operand with a single binary digit ends the recursion
    # and the product is computed directly.
    if len(x.binary_vec) == 1 or len(y.binary_vec) == 1:
        return list(bin(x.decimal_val * y.decimal_val)[2:])

    # Pad to a shared even length so both operands split at the same bit.
    x_array, y_array = pad(x.binary_vec, y.binary_vec)

    # Split each operand into a high half (left) and a low half (right):
    # x = x_left * 2^(n/2) + x_right, and likewise for y.
    x_left, x_right = split_number(x_array)
    y_left, y_right = split_number(y_array)

    # product of the high halves
    product1 = binary2int(subquadratic_multiply(x_left, y_left))

    # product of the low halves
    product2 = binary2int(subquadratic_multiply(x_right, y_right))

    # Karatsuba's trick: one product of the half-sums stands in for the two
    # cross products, so each level makes 3 recursive calls instead of 4.
    product3 = binary2int(subquadratic_multiply(
        BinaryNumber(x_left.decimal_val + x_right.decimal_val),
        BinaryNumber(y_left.decimal_val + y_right.decimal_val),
    ))

    n = len(x_array)

    # (xL + xR)(yL + yR) - xL*yL - xR*yR == xL*yR + xR*yL, never negative
    cross = product3.decimal_val - product1.decimal_val - product2.decimal_val

    # x*y = p1 * 2^n + cross * 2^(n/2) + p2
    result = (bit_shift(product1, n).decimal_val
              + product2.decimal_val
              + bit_shift(BinaryNumber(cross), n // 2).decimal_val)
    return list(bin(result)[2:])

print(binary2int(subquadratic_multiply(BinaryNumber(2), BinaryNumber(2))))
print(binary2int(subquadratic_multiply(BinaryNumber(3), BinaryNumber(7))))

decimal=4 binary=100
decimal=21 binary=10101


#### Examples to go into the Python file with assertions. I had to make some changes (these versions return plain integers instead of bit lists); hope you don't mind.

In [21]:
def quadratic_multiply(x, y):
    """Multiply two BinaryNumbers using four half-size recursive products.

    Args:
        x, y: BinaryNumber operands.

    Returns:
        The product x * y as a plain int.
    """
    x_digits, y_digits = x.binary_vec, y.binary_vec
    # Base case: a single-digit operand stops the recursion.
    if 1 in (len(x_digits), len(y_digits)):
        return x.decimal_val * y.decimal_val

    # Equal, even-length bit lists let both operands split at the same point.
    x_bits, y_bits = pad(x_digits, y_digits)
    width = len(x_bits)
    x_hi, x_lo = split_number(x_bits)
    y_hi, y_lo = split_number(y_bits)

    # Four recursive sub-products per level -- the source of the
    # quadratic running time.
    hi_hi = quadratic_multiply(x_hi, y_hi)
    lo_lo = quadratic_multiply(x_lo, y_lo)
    hi_lo = quadratic_multiply(x_hi, y_lo)
    lo_hi = quadratic_multiply(x_lo, y_hi)

    # x*y = hi_hi * 2^width + (hi_lo + lo_hi) * 2^(width/2) + lo_lo
    shifted_hi = bit_shift(BinaryNumber(hi_hi), width).decimal_val
    shifted_cross = bit_shift(BinaryNumber(hi_lo + lo_hi), width // 2).decimal_val
    return shifted_hi + lo_lo + shifted_cross

print(quadratic_multiply(BinaryNumber(2), BinaryNumber(2)))

4


In [24]:
def subquadratic_multiply(x, y):
    """Multiply two BinaryNumbers with Karatsuba's three-product recursion.

    Args:
        x, y: BinaryNumber operands.

    Returns:
        The product x * y as a plain int.
    """
    # Base case: an operand with a single binary digit ends the recursion
    # and the product is computed directly.
    if len(x.binary_vec) == 1 or len(y.binary_vec) == 1:
        return x.decimal_val * y.decimal_val

    # Pad to a shared even length so both operands split at the same bit.
    x_array, y_array = pad(x.binary_vec, y.binary_vec)

    # Split each operand into a high half (left) and a low half (right):
    # x = x_left * 2^(n/2) + x_right, and likewise for y.
    x_left, x_right = split_number(x_array)
    y_left, y_right = split_number(y_array)

    n = len(x_array)

    # product of the high halves, already shifted into place
    product1 = subquadratic_multiply(x_left, y_left)
    p1 = bit_shift(BinaryNumber(product1), n)

    # product of the low halves
    product2 = subquadratic_multiply(x_right, y_right)
    p2 = BinaryNumber(product2)

    # Karatsuba's trick: one product of the half-sums stands in for the two
    # cross products, so each level makes 3 recursive calls instead of 4.
    product3 = subquadratic_multiply(
        BinaryNumber(x_left.decimal_val + x_right.decimal_val),
        BinaryNumber(y_left.decimal_val + y_right.decimal_val),
    )

    # (xL + xR)(yL + yR) - xL*yL - xR*yR == xL*yR + xR*yL, never negative
    cross = product3 - product1 - product2
    p3 = bit_shift(BinaryNumber(cross), n // 2)

    # x*y = xL*yL * 2^n + (xL*yR + xR*yL) * 2^(n/2) + xR*yR
    return p1.decimal_val + p2.decimal_val + p3.decimal_val

In [None]:
# some timing functions here that will make comparisons easy
def time_multiply(x, y, f):
    """Return the elapsed time, in milliseconds, of a single call ``f(x, y)``.

    The value returned by ``f`` is discarded; only the duration is reported.
    time.perf_counter() is used because it is the clock intended for timing
    short intervals. time.time() can be coarse and is affected by system
    clock adjustments, which matters for the sub-millisecond runs here.
    """
    start = time.perf_counter()
    f(x, y)
    return (time.perf_counter() - start) * 1000

def compare_multiply():
    """Time both multipliers on n * n for n = 10, 100, ..., 10**9 and print a table."""
    res = [
        (
            n,
            time_multiply(BinaryNumber(n), BinaryNumber(n), quadratic_multiply),
            time_multiply(BinaryNumber(n), BinaryNumber(n), subquadratic_multiply),
        )
        for n in (10 ** k for k in range(1, 10))
    ]
    print_results(res)


def print_results(results):
    """Print benchmark rows as a GitHub-flavoured markdown table.

    Rows are presumably (size, quadratic_ms, subquadratic_ms) tuples, matching
    the headers below; floats are shown with three decimals.
    """
    print("\n")
    table = tabulate(
        results,
        headers=['n', 'quadratic', 'subquadratic'],
        floatfmt=".3f",
        tablefmt="github",
    )
    print(table)

In [None]:
# Time one multiplication of 1,000,000 by itself with each algorithm.
n = 1000000
for multiply in (quadratic_multiply, subquadratic_multiply):
    runtime = time_multiply(BinaryNumber(n), BinaryNumber(n), multiply)
    print(f"{multiply.__name__}({n}, {n}) took {runtime:.3f} ms")

quadratic_multiply(1000000, 1000000) took 0.382 ms
subquadratic_multiply(1000000, 1000000) took 0.369 ms


In [None]:
# Run the decimal-magnitude benchmark defined above (n = 10 ... 10**9) and print the table.
compare_multiply()



|          n |   quadratic |   subquadratic |
|------------|-------------|----------------|
|         10 |       0.055 |          0.038 |
|        100 |       0.050 |          0.073 |
|       1000 |       0.106 |          0.190 |
|      10000 |       0.103 |          0.204 |
|     100000 |       0.191 |          0.269 |
|    1000000 |       0.333 |          0.329 |
|   10000000 |       0.249 |          0.365 |
|  100000000 |       0.430 |          0.594 |
| 1000000000 |       0.494 |          0.760 |


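#### These results were confusing at first. Every input above is at most 30 bits long (10^9 < 2^30), so fixed overhead hides any asymptotic difference. Below, the benchmark uses operand bit length instead of decimal magnitude.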

#### With larger operands, the fixed per-call overhead becomes negligible compared with the multiplication work, so the true difference between the two algorithms shows at scale.

In [None]:
import random 

def compare_multiply(bit_lengths=(8, 16, 32, 64, 128, 256, 512, 1024), seed=None):
    """Time both multipliers on random operands of increasing bit length.

    For each bit length, two random operands of at most that many bits are
    drawn with ``getrandbits``. Each algorithm multiplies them once, and the
    elapsed milliseconds are printed as a table via print_results.

    Args:
        bit_lengths: iterable of operand bit lengths to benchmark.
        seed: optional seed for a private ``random.Random`` so a run can be
            reproduced exactly. None (the default) draws from the global
            ``random`` module state, as before.
    """
    rng = random if seed is None else random.Random(seed)
    results = []
    # test based on bit length with random numbers
    for bits in bit_lengths:
        # https://docs.python.org/3/library/random.html#random.getrandbits
        n1 = rng.getrandbits(bits)
        n2 = rng.getrandbits(bits)
        bn1, bn2 = BinaryNumber(n1), BinaryNumber(n2)

        qtime = time_multiply(bn1, bn2, quadratic_multiply)
        subqtime = time_multiply(bn1, bn2, subquadratic_multiply)

        results.append((bits, qtime, subqtime))

    print_results(results)

compare_multiply()



|    n |   quadratic |   subquadratic |
|------|-------------|----------------|
|    8 |       0.030 |          0.035 |
|   16 |       0.180 |          0.178 |
|   32 |       0.529 |          0.476 |
|   64 |       2.125 |          1.563 |
|  128 |      12.234 |          5.662 |
|  256 |      38.567 |         12.654 |
|  512 |     148.819 |         40.263 |
| 1024 |     553.245 |        118.540 |
