# Reed Solomon Codes

The following code is an implementation of Reed–Solomon codes in Python, using the multiple-precision arithmetic library gmpy2 (its `mpz` and `mpfr` types).
The gmpy2 documentation is available at https://gmpy2.readthedocs.io/

Implemented by -
1. Siddharth Kothari (<b>IMT2021019</b>)
2. Sankalp Kothari (<b>IMT2021028</b>)


# System setup


To run on Google Colab, go to Google Colab, click on File > Upload Notebook, and navigate to the location where you saved this file. Select the file and upload it.

<br>

To run the notebook on your local machine (on Ubuntu), install Anaconda by following the steps at the link below:
https://docs.anaconda.com/free/anaconda/install/linux/

Once the Anaconda installation is complete, run the following command in a terminal to install gmpy2.


In [None]:
conda install -c anaconda gmpy2

1. To run this Jupyter notebook in VS Code, simply open it in VS Code.
2. Select the Anaconda Python interpreter (the one marked 'Conda') to run the notebook.
3. To change the interpreter, open the Command Palette (Ctrl + Shift + P), run the "Python: Select Interpreter" command, and choose the Anaconda interpreter.

# Code Implementation

In [4]:
# Code Implementation
from gmpy2 import is_prime, mpz, mpfr
import random
from math import floor


import sys

# Messages are integers in [0, M], where M = 2**length.
length = int(input("Enter bit length of message: "))
# `1 << length` raises ValueError for a negative shift count; reject it cleanly.
if length < 0:
    print("Bit length can't be negative")
    sys.exit()
M = mpz(1 << length)

print(f"Max value of message: {M}")

a = int(input("Enter message: "))
if a > M:
    print(f"Message can't be greater than {M}")
    sys.exit()
# The decoder reconstructs the message as |r| / |t|, so it can only ever
# return a non-negative value; a negative message could never be recovered.
if a < 0:
    print("Message can't be negative")
    sys.exit()

print(f"message : {a}")

# mu bounds the fraction of residues the channel is allowed to corrupt.
mu = float(input("Enter corruption factor: "))
# Chained comparison so that NaN (for which both `mu < 0` and `mu > 1` are
# False) is rejected as well.
if not (0 <= mu <= 1):
    print("Corruption factor has to be >=0 and <=1")
    sys.exit()

print(f"corruption factor : {mu}")

# all the functions and classes below use arguments which are of the type mpz
def multMod(a,b,n):
    """Return (a * b) reduced modulo n (Python `%` semantics)."""
    product = a * b
    return product % n

def binaryEGCD(a,b):
    """Binary extended Euclidean algorithm.

    Args:
        a, b: non-negative integers (int or mpz).

    Returns:
        (d, s, t) with d == gcd(a, b) and a*s + b*t == d.
    """
    # gcd(0, b) == b == 0*a + 1*b. Without this guard the `while r % 2 == 0`
    # loop below never terminates, because r would stay 0 forever.
    if a == 0:
        return (b, 0, 1)

    r, r_prime = a, b
    e = 0
    # Remove the common power of two; it is multiplied back in at the end.
    while r % 2 == 0 and r_prime % 2 == 0:
        r >>= 1
        r_prime >>= 1
        e += 1

    a_prime, b_prime = r, r_prime

    # Invariants:
    #   r       == s*a_prime       + t*b_prime
    #   r_prime == s_prime*a_prime + t_prime*b_prime
    s, t = 1, 0
    s_prime, t_prime = 0, 1

    while r_prime > 0:
        while r % 2 == 0:
            r >>= 1
            if s % 2 == 0 and t % 2 == 0:
                s >>= 1
                t >>= 1
            else:
                # Adding (b_prime, -a_prime) keeps the invariant and makes both
                # coefficients even, so they can be halved exactly.
                s = (s + b_prime) >> 1
                t = (t - a_prime) >> 1

        while r_prime % 2 == 0:
            r_prime >>= 1
            if s_prime % 2 == 0 and t_prime % 2 == 0:
                s_prime >>= 1
                t_prime >>= 1
            else:
                s_prime = (s_prime + b_prime) >> 1
                t_prime = (t_prime - a_prime) >> 1

        # Keep r <= r_prime, then subtract (both are odd here).
        if r_prime < r:
            r, r_prime = r_prime, r
            s, s_prime = s_prime, s
            t, t_prime = t_prime, t

        r_prime = r_prime - r
        s_prime = s_prime - s
        t_prime = t_prime - t

    return (r * (2 ** e), s, t)

def EGCD(a, b):
    """Extended Euclidean algorithm that keeps the whole remainder sequence.

    Args:
        a, b: integers (int or mpz); the file calls it with a = n, 0 <= b < n.

    Returns:
        (rlist, slist, tlist) where rlist = [a, b, r_2, ..., 0] is the sequence
        of Euclidean remainders, always ending in 0, and for every index i
        a*slist[i] + b*tlist[i] == rlist[i].
        For non-negative a, b, rlist[-2] == gcd(a, b).
    """
    rlist = [a, b]
    slist = [1, 0]
    tlist = [0, 1]

    # Loop on the latest remainder: b == 0 performs no division (the original
    # raised ZeroDivisionError) and b | a still gets its trailing 0 remainder.
    while rlist[-1] != 0:
        q = rlist[-2] // rlist[-1]
        rlist.append(rlist[-2] % rlist[-1])
        slist.append(slist[-2] - q * slist[-1])
        tlist.append(tlist[-2] - q * tlist[-1])

    return (rlist, slist, tlist)


def getInverse(a, n):
    """Return the Bezout coefficient of `a` from binaryEGCD(a, n).

    When gcd(a, n) == 1 this is an inverse of a modulo n. It is not reduced
    into [0, n) and may be negative. No coprimality check is performed.
    """
    _gcd, coeff_a, _coeff_n = binaryEGCD(a, n)
    return coeff_a
    

class CRT:
    """Chinese Remainder Theorem helper for a fixed list of moduli.

    The moduli are assumed pairwise coprime; this file passes distinct primes.
    """

    def __init__(self, N):
        # N represents the list of moduli used (n1, n2, ..., nk)
        self.N = N

    def getPrimes(self):
        """Return the list of moduli (n1, ..., nk)."""
        return self.N

    def getProduct(self):
        """Return n = n1 * n2 * ... * nk as an mpz."""
        n = mpz(1)
        for ni in self.N:
            n *= ni
        return n

    def getCRTMap(self, a):
        """Return the CRT image [a mod n1, ..., a mod nk] of a."""
        return [a % ni for ni in self.N]

    def getPreImage(self, A):
        """Return a in [0, n) with a ≡ A[i] (mod N[i]) for every i.

        Args:
            A: the residues [a1, ..., ak], one per modulus.

        Returns:
            The reconstructed value, or -1 if len(A) != number of moduli.
        """
        if len(A) != len(self.N):
            return -1

        n = self.getProduct()
        a = 0
        for ai, ni in zip(A, self.N):
            # n_i* = n / n_i is the product of all the other moduli. Computing it
            # by exact division is O(k) and, unlike the old nested loop, also
            # works when there is only a single modulus.
            ni_star = n // ni
            ti = getInverse(ni_star % ni, ni)
            if ti < 0:
                ti += ni

            # e_i ≡ 1 (mod n_i) and e_i ≡ 0 (mod n_j) for j != i.
            ei = ni_star * ti
            a = (a + multMod(ai, ei, n)) % n

        return a

# k = CRT([mpz(3), mpz(5), mpz(7)])
# print(k.getPrimes())
# print(k.getProduct())
# print(k.getCRTMap(mpz(68)))
# print(k.getPreImage([mpz(2), mpz(3), mpz(5)]))

def GlobalSetup_correct(mu, M, bits=16):
    """Choose primes n1 > n2 > ... > nk so that decoding always succeeds.

    Every prime is odd and below 2**bits. The product P of the at most
    floor(mu*k) corrupted primes therefore satisfies P <= 2**(bits*mu*k), so
    P**2 <= C**k with C = 2**(2*bits*mu). The decoder needs
    2*M*P**2 < n1*n2*...*nk, which is guaranteed once
    n1*...*nk > 2*M*C**k. Primes are taken from 2**bits - 1 downward until
    this holds.

    Args:
        mu: corruption factor in [0, 1].
        M: upper bound on the message.
        bits: primes are chosen below 2**bits (default 16, the original value).

    Returns:
        List of mpz primes, largest first.

    Raises:
        ValueError: if the odd primes below 2**bits run out before the
            condition holds. With bits=16 this always happens for mu >= 0.5,
            because each prime adds < 16 bits to the product while each factor
            of C adds 32*mu >= 16 bits to the threshold.
    """
    C = mpfr(2 ** (2 * bits * mu))
    primes = []

    threshold = mpz(2 * M)
    product = mpz(1)

    candidate = mpz(2 ** bits - 1)
    while threshold >= product:
        # Previously the loop just stopped here and returned primes that do not
        # satisfy the bound, making later decoding fail silently.
        if candidate <= 1:
            raise ValueError(
                f"ran out of primes below 2**{bits} before n1*...*nk exceeded "
                f"2*M*C**k (mu={mu}); use a smaller corruption factor"
            )
        if candidate.is_prime(40):
            product *= candidate
            primes.append(candidate)
            # Each selected prime raises the requirement by another factor C.
            threshold *= C
        candidate -= 2

    return primes

def GlobalSetup_incorrect(mu, M):
    """Choose a user-specified number k of primes without checking the decoding bound.

    k is read from standard input. The condition 2*M*P**2 < n1*...*nk is NOT
    checked, so decoding may fail. Presumably this exists to demonstrate that
    failure; confirm with the authors. mu and M are accepted but unused.

    Returns:
        The k largest odd primes below 2**16 (as mpz), largest first; an empty
        list if k <= 0.

    Raises:
        ValueError: if k exceeds the number of odd primes below 2**16.
    """
    k = int(input("Enter number of primes: "))
    primes = []

    candidate = mpz(2**16 - 1)
    while len(primes) < k:
        # Without this bound the loop would decrement forever once the odd
        # primes below 2**16 are exhausted.
        if candidate <= 1:
            raise ValueError(f"there are fewer than {k} odd primes below 2**16")
        if candidate.is_prime(40):
            primes.append(candidate)
        candidate -= 2

    return primes


def Transmit(A):
    """Simulate a noisy channel by corrupting at most floor(mu*k) residues of A.

    A is the list of residues [a mod n1, ..., a mod nk]. It is modified in
    place and also returned.

    The number of corrupted positions l is drawn uniformly from
    0..floor(mu*k). Exactly l distinct positions are chosen, and each chosen
    A[i] is replaced by a different value drawn from [0, primes[i]).

    Reads the module-level globals `primes` and `mu`.
    """
    global primes
    global mu
    k = len(A)

    l = random.randint(0, floor(mu * k))

    # Choose exactly l distinct positions: the decoder can only correct up to
    # floor(mu*k) errors, and l can equal k when mu == 1.
    for i in random.sample(range(k), l):
        while True:
            bi = random.randint(0, primes[i] - 1)
            if bi != A[i]:
                break
        A[i] = bi

    return A
            
def ReedSolomonSend(a):
    """Encode message `a` as its residues modulo each prime, then send them through Transmit.

    Returns the (possibly corrupted) residue list produced by Transmit.
    Reads the module-level global `primes`.
    """
    global primes
    residues = CRT(primes).getCRTMap(a)
    return Transmit(residues)

def ReedSolomonReceive(B):
    """Decode a message from the received residues B = [b1, ..., bk].

    Steps:
    1. Reconstruct b in [0, n) with b ≡ bi (mod ni) via the CRT.
    2. Perform rational reconstruction: run the extended Euclidean algorithm
       on (n, b) and take the first remainder r_i <= r* = M*P.
    3. Return |r_i| / |t_i|.

    P is the product of the floor(mu*k) largest primes. This is an upper bound
    on the product of the corrupted moduli, assuming at most floor(mu*k)
    residues were corrupted.

    Reads the module-level globals `primes`, `M` and `mu`.

    Returns:
        The recovered message as an mpz, or -1 if decoding fails. Failure
        means: wrong number of residues, no remainder <= r*, t_i == 0, or
        t_i does not divide r_i.
    """
    global primes
    global M
    global mu

    if len(B) != len(primes):
        return -1

    obj = CRT(primes)
    n = obj.getProduct()
    b = obj.getPreImage(B)

    r, _, t = EGCD(n, b)

    # P = product of the l largest primes, l = floor(mu*k).
    l = floor(mu * len(primes))
    P = 1
    for p in sorted(primes, reverse=True)[:l]:
        P *= p

    r_star = M * P

    # Find the first remainder at or below the bound. If there is none,
    # decoding has failed; this used to raise UnboundLocalError on t_dash.
    for r_dash, t_dash in zip(r, t):
        if r_dash <= r_star:
            break
    else:
        return -1

    if t_dash == 0:
        return -1

    if abs(r_dash) % abs(t_dash) == 0:
        return mpz(abs(r_dash) // abs(t_dash))
    return -1
    

# Pick the primes (moduli) from the corruption factor and message bound read
# above. `primes` is a module-level global that the send/receive functions read.
primes = GlobalSetup_correct(mu, M)
# print(primes)

# Encode and "transmit" the message (some residues may be corrupted), then
# try to decode it; ReedSolomonReceive returns -1 when decoding fails.
B= ReedSolomonSend(a)
ans=ReedSolomonReceive(B)

print()
print(f"Reconstructed message : {ans}")


Max value of message: 179769313486231590772930519078902473361797697894230657273430081157732675805500963132708477322407536021120113879871393357658789768814416622492847430639474124377767893424865485276302219601246094119453082952085005768838150682342462881473913110540827237163350510684586298239947245938479716304835356329624224137216
message : 477657824652871487175964576849758942175894678568746584758748275346587346583475983985738645876435798427586458763454327
corruption factor : 0.1

Reconstructed message : 477657824652871487175964576849758942175894678568746584758748275346587346583475983985738645876435798427586458763454327
