# Mini Projekt - Baby Kyber

## Pierścień $\mathbb{Z}_{17}[X]/(X^4+1)$

In [207]:
# Polynomial quotient ring building blocks, copied from the first lab session.
import numbers


class Zn():
    """Element of Z_N, the integers modulo N.

    Arithmetic with another Zn of the same modulus, or with an integer
    (a Python int or any numbers.Integral such as a numpy integer), returns
    a new Zn. Unsupported operand types return NotImplemented, so Python
    raises TypeError.

    Raises:
        ValueError: when combining two Zn with different moduli, or when a
            negative power has no modular inverse.
    """

    def __init__(self, value, N):
        self.N = N
        # Normalise numpy integers to Python ints so later arithmetic
        # (notably powers) cannot overflow fixed-width integer types.
        if isinstance(value, numbers.Integral):
            value = int(value)
        self.value = value % N

    def _operand(self, other, symbol, verb):
        """Return other's integer value, or None if its type is unsupported.

        Raises ValueError if other is a Zn from a ring with a different N.
        """
        if isinstance(other, Zn):
            if self.N != other.N:
                raise ValueError(f"Cannot {verb} ({symbol}) numbers from different rings")
            return other.value
        if isinstance(other, numbers.Integral):
            return int(other)
        return None

    def _modpow(self, base, exponent):
        """Return base ** exponent reduced modulo N."""
        if isinstance(base, int) and isinstance(self.N, int):
            # Three-argument pow keeps intermediates below N. A negative
            # exponent yields the modular inverse (Python 3.8+), or raises
            # ValueError when none exists.
            return pow(base, exponent, self.N)
        return base ** exponent

    def __add__(self, other):
        addend = self._operand(other, "+", "add")
        if addend is None:
            return NotImplemented
        return Zn(self.value + addend, self.N)

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        subtrahend = self._operand(other, "-", "subtract")
        if subtrahend is None:
            return NotImplemented
        return Zn(self.value - subtrahend, self.N)

    def __rsub__(self, other):
        # Computes other - self (the left operand did not know how to subtract a Zn).
        minuend = self._operand(other, "-", "subtract")
        if minuend is None:
            return NotImplemented
        return Zn(minuend - self.value, self.N)

    def __mul__(self, other):
        factor = self._operand(other, "*", "multiply")
        if factor is None:
            return NotImplemented
        return Zn(self.value * factor, self.N)

    def __rmul__(self, other):
        return self.__mul__(other)

    def __pow__(self, other):
        exponent = self._operand(other, "**", "power")
        if exponent is None:
            return NotImplemented
        return Zn(self._modpow(self.value, exponent), self.N)

    def __rpow__(self, other):
        # Computes other ** self.
        base = self._operand(other, "**", "power")
        if base is None:
            return NotImplemented
        return Zn(self._modpow(base, self.value), self.N)

    def __repr__(self):
        return f"{self.value}"


import numbers


class ZnW():
    """Element of the quotient ring Z_N[X]/(W).

    Coefficient lists are little-endian: ``w[i]`` is the coefficient of
    ``x^i`` (see ``__repr__``). ``W`` uses the same layout, e.g.
    ``[1, 0, 0, 0, 1]`` for ``x^4 + 1``.

    Attributes:
        N: coefficient modulus.
        W: reduction polynomial, as a list of Zn.
        w: reduced polynomial, as a list of Zn. Zero top coefficients may be
            dropped by reduce(), so len(w) can be smaller than deg(W).

    Plain integers (any numbers.Integral) are treated as constant
    polynomials in +, - and *. Other operand types raise ValueError.
    """

    def __init__(self, w, N, W):
        self.N = N
        self.W = [Zn(c, N) for c in W]
        self.w = [Zn(c, N) for c in w]
        self.reduce()

    def _modulus_values(self):
        """Return the reduction polynomial's coefficients as plain values."""
        return [c.value for c in self.W]

    def _check_compatible(self, other):
        """Raise ValueError unless other is in the same ring as self."""
        if self.N != other.N or self._modulus_values() != other._modulus_values():
            raise ValueError("ZnW instances must have the same modulus and reduction polynomial")

    def _as_polynomial(self, other, verb):
        """Return other as a ZnW of this ring; integers become constants."""
        if isinstance(other, numbers.Integral):
            return ZnW([int(other)], self.N, self._modulus_values())
        if not isinstance(other, ZnW):
            raise ValueError(f"Can only {verb} ZnW or int instances")
        self._check_compatible(other)
        return other

    def reduce(self):
        """Replace self.w in place by its remainder modulo W (long division)."""
        # Ignore zero leading coefficients of W; otherwise the loop below
        # could never cancel the top term of w.
        divisor = list(self.W)
        while divisor and divisor[-1].value == 0:
            divisor.pop()
        if not divisor:
            return  # reducing modulo the zero polynomial changes nothing
        width = len(divisor)
        if len(self.w) < width:
            return
        modulus = int(self.N)
        lead = int(divisor[-1].value) % modulus
        if lead == 1:
            lead_inv = 1
        else:
            # Non-monic W: scale by the inverse of its leading coefficient.
            try:
                lead_inv = pow(lead, -1, modulus)
            except ValueError:
                raise ValueError(
                    "Leading coefficient of the reduction polynomial must be invertible mod N"
                ) from None
        while len(self.w) >= width:
            top = self.w[-1].value
            if top == 0:
                self.w.pop()
                continue
            factor = int(top) * lead_inv % modulus
            shift = len(self.w) - width
            for i, coef in enumerate(divisor):
                self.w[shift + i] -= factor * coef
            while self.w and self.w[-1].value == 0:
                self.w.pop()

    def __add__(self, other):
        other = self._as_polynomial(other, "add")
        size = max(len(self.w), len(other.w))
        zero = Zn(0, self.N)
        left = self.w + [zero] * (size - len(self.w))
        right = other.w + [zero] * (size - len(other.w))
        total = [x + y for x, y in zip(left, right)]
        return ZnW([c.value for c in total], self.N, self._modulus_values())

    def __radd__(self, other):
        # Makes 0 + poly (and hence sum()) work.
        return self + other

    def __neg__(self):
        return self * -1

    def __sub__(self, other):
        return self + (-self._as_polynomial(other, "subtract"))

    def __rsub__(self, other):
        return -self + other

    def __mul__(self, other):
        if isinstance(other, numbers.Integral):
            scale = int(other)
            return ZnW([(c * scale).value for c in self.w], self.N, self._modulus_values())

        if isinstance(other, ZnW):
            self._check_compatible(other)
            # Schoolbook product; the constructor then reduces modulo W.
            size = max(len(self.w) + len(other.w) - 1, 0)
            product = [Zn(0, self.N) for _ in range(size)]
            for i, a in enumerate(self.w):
                for j, b in enumerate(other.w):
                    product[i + j] += a * b
            return ZnW([c.value for c in product], self.N, self._modulus_values())

        raise ValueError("Can only multiply ZnW or int instances")

    def __rmul__(self, other):
        return self * other

    def __repr__(self):
        terms = [f"{coef}x^{i}" for i, coef in enumerate(self.w) if coef.value != 0]
        return " + ".join(terms) if terms else "0"

## Baby Kyber

Zaimplementuj poniższe elementy kryptosystemu Baby Kyber tak, aby osiągnąć jak największą skuteczność w testach (przy niezerowych błędach). Wymagana minimalna skuteczność to 60%.

In [268]:
import numpy as np

# System parameters.
k = 2        # number of polynomials in s, e, r (A is a k x k matrix)
q = 3329     # coefficient modulus
n = 4        # number of coefficients per polynomial

W = [1,0,0,0,1]  # reduction polynomial x^4 + 1, lowest degree first


def random_ZnW(q, n):
    """Return a ZnW whose n coefficients are drawn uniformly from [0, q)."""
    coefficients = []
    for _ in range(n):
        coefficients.append(np.random.randint(0, q))
    return ZnW(coefficients, q, W)


def random_B_n(q, n):
    """Return a ZnW with n small coefficients from {-1, 0, 1}.

    Each coefficient is drawn independently with probabilities 1/4, 1/2, 1/4.
    """
    probabilities = [1/4, 1/2, 1/4]
    small = np.random.choice([-1, 0, 1], size=n, p=probabilities)
    return ZnW(small, q, W)


def generate_A(k):
    """Return a k x k numpy object array of uniformly random ZnW entries."""
    rows = []
    for _ in range(k):
        rows.append([random_ZnW(q, n) for _ in range(k)])
    return np.array(rows)

### Generowanie klucza

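Zaimplementuj funkcję `key_gen()` realizującą generowanie klucza w kryptosystemie Baby Kyber. Funkcja ma zwracać `A,t,s`. Przetestuj, czy dla podanych $A,\mathbf{s},\mathbf{e}$ otrzymasz poprawny wektor wielomianów $\mathbf{t}$. Uwaga: przykładowe dane poniżej są w pierścieniu $\mathbb{Z}_{17}[X]/(X^4+1)$, więc w tym teście należy przyjąć $q=17$ (dla $q=3329$ wynik będzie inny).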

$A=\left[\begin{matrix}
    6x^3+16x^2+16x+11&9x^3+4x^2+6x+3\\
    5x^3+3x^2+10x+1&6x^3+x^2+9x+15
\end{matrix}\right]$

$\mathbf{s}=(-x^3-x^2+x,-x^3-x)$

$\mathbf{e}=(x^2,x^2-x)$

$\mathbf{t}=A\mathbf{s}+\mathbf{e}:\ \ \mathbf{t}=(16x^3+15x^2+7,10x^3+12x^2+11x+6)$

In [269]:
def key_gen(A=None, s=None, e=None):
    """Generate a Baby Kyber key pair.

    Computes t = A s + e and returns (A, t, s). (A, t) is the public key and
    s is the secret key.

    Any of A, s, e may be passed explicitly, for example to reproduce the
    assignment's worked example over Z_17. Per the assignment, that example
    should give t = (16x^3+15x^2+7, 10x^3+12x^2+11x+6). Omitted values are
    sampled: A uniformly via generate_A(k), and s and e with coefficients in
    {-1, 0, 1} via random_B_n. s and e are sampled in the same ring as A.

    Returns:
        (A, t, s) as numpy object arrays of ZnW.
    """
    if A is None:
        A = generate_A(k)
    A = np.asarray(A)

    # Take the ring parameters from A so explicitly supplied inputs over
    # another modulus (e.g. q = 17) are matched by the sampled ones.
    rank = A.shape[0]
    modulus = A[0, 0].N
    degree = len(A[0, 0].W) - 1

    if s is None:
        s = np.array([random_B_n(modulus, degree) for _ in range(rank)])
    if e is None:
        e = np.array([random_B_n(modulus, degree) for _ in range(rank)])
    s = np.asarray(s)

    t = A @ s + np.asarray(e)

    return A, t, s

# Generate a fresh key pair and display (A, t, s).
A, t, s = key_gen()
print(A, t, s, sep=" ")

[[1753x^0 + 3084x^1 + 2551x^2 + 953x^3 808x^0 + 1591x^1 + 368x^2 + 387x^3]
 [1423x^0 + 2282x^1 + 386x^2 + 1802x^3
  2798x^0 + 3185x^1 + 1007x^2 + 216x^3]] [1813x^0 + 940x^1 + 1762x^2 + 589x^3 3104x^0 + 1182x^1 + 2802x^2 + 1593x^3] [1x^1 + 1x^3 3328x^0]


### Szyfrowanie

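Zaimplementuj funkcję `encrypt(A,t,m)` realizującą szyfrowanie w kryptosystemie Baby Kyber, gdzie wejściowa wiadomość `m` jest podana jako lista współczynników od $x^0$ w górę (np. $x^3+x+1$ to `[1, 1, 0, 1]`). Funkcja ma zwracać szyfrogram `c`. Przetestuj poprawność działania na poniższych danych (jak wyżej, w pierścieniu $\mathbb{Z}_{17}[X]/(X^4+1)$).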

$m=1\cdot x^3+0\cdot x^2+1\cdot x+1=x^3+x+1$

$\mathbf{r}=(-x^3+x^2,x^3+x^2-1)$

$\mathbf{e_1}=(x^2+x,x^2)$

$e_2=-x^3-x^2$

$\mathbf{u}=A^T\mathbf{r}+\mathbf{e_1}:\ \ \mathbf{u}=(11x^3+11x^2+10x+3,4x^3+4x^2+13x+11)$

$v=\mathbf{t}^T\mathbf{r}+e_2+\lfloor\frac{q}{2}\rceil m:\ \ v=8x^3+6x^2+9x+16$

$\mathbf{c}=(\mathbf{u},v):\ \ \mathbf{c}=((11x^3+11x^2+10x+3,4x^3+4x^2+13x+11),8x^3+6x^2+9x+16)$

In [270]:
def encrypt(A, t, m, r=None, e_1=None, e_2=None):
    """Encrypt a message under the Baby Kyber public key (A, t).

    Computes
        u = A^T r + e_1
        v = t^T r + e_2 + round(q/2) * m
    and returns the ciphertext c = (u, v).

    Args:
        A: public k x k matrix of ZnW.
        t: public vector of k ZnW.
        m: message bits as a list of coefficients, lowest degree first
            (e.g. [1, 1, 0, 1] is x^3 + x + 1).
        r, e_1, e_2: optional randomness. Omitted values are sampled with
            random_B_n in the ring of t. Passing them reproduces a worked
            example such as the assignment's q = 17 data.

    Returns:
        Tuple (u, v): u is a numpy object array of ZnW, v is a ZnW.
    """
    A = np.asarray(A)
    t = np.asarray(t)

    # Take the ring parameters from the public key, not the module globals,
    # so keys over a different modulus still work.
    rank = len(t)
    modulus = t[0].N
    reduction = [coef.value for coef in t[0].W]
    degree = len(reduction) - 1

    if r is None:
        r = np.array([random_B_n(modulus, degree) for _ in range(rank)])
    if e_1 is None:
        e_1 = np.array([random_B_n(modulus, degree) for _ in range(rank)])
    if e_2 is None:
        e_2 = random_B_n(modulus, degree)
    r = np.asarray(r)

    u = A.T @ r + np.asarray(e_1)

    # round(q/2) for an integer q: equals ceil(q/2), computed without floats.
    half_q = int((modulus + 1) // 2)
    v = t @ r + e_2 + ZnW(m, modulus, reduction) * half_q

    c = (u, v)

    return c


# Encrypt the message x^3 + x + 1 (coefficients listed from x^0 upward).
m = [1, 1, 0, 1]
c = encrypt(A, t, m)
print(c)

(array([421x^0 + 3308x^1 + 868x^2 + 2777x^3,
       2493x^0 + 2540x^1 + 542x^2 + 2178x^3], dtype=object), 3075x^0 + 2008x^1 + 3316x^2 + 776x^3)


### Deszyfrowanie

Zaimplementuj funkcję `decrypt(c,s)` realizującą deszyfrowanie w kryptosystemie Baby Kyber. Funkcja ma zwracać ostateczną odszyfrowaną wiadomość `m_n`. Przetestuj działanie na poniższych danych.

$m_n=v-\mathbf{s}^T\mathbf{u}:\ \ m_n=8x^3+14x^2+8x+6$

$m_n=1\cdot x^3+0\cdot x^2+1\cdot x+1$


In [271]:
def decrypt(c, s):
    """Decrypt a Baby Kyber ciphertext c = (u, v) with the secret key s.

    Computes the noisy message polynomial m_n = v - s^T u. Each coefficient
    is decoded to a bit: 1 if it lies in [q/4, 3q/4], roughly "closer to q/2
    than to 0 modulo q", otherwise 0. q is taken from the ciphertext itself.

    Returns:
        List of deg(W) bits, lowest-degree coefficient first (the same
        layout as the m passed to encrypt).
    """
    u, v = c[0], c[1]
    noisy = v + (-1) * np.transpose(s) @ u

    modulus = noisy.N
    degree = len(noisy.W) - 1

    # ZnW may strip zero top coefficients, so pad so that exactly `degree`
    # bits come back. A missing coefficient is 0 and decodes to 0.
    values = [coef.value for coef in noisy.w]
    values += [0] * (degree - len(values))

    # q/4 <= value <= 3q/4, in integer arithmetic.
    return [1 if modulus <= 4 * value <= 3 * modulus else 0 for value in values]

# Decrypt the ciphertext and display the recovered bit list.
m_n = decrypt(c, s)
print(m_n)

[1, 1, 0, 1]


### Testy

In [273]:
import secrets as sc

# Empirical success rate: for each trial, generate keys, encrypt a random
# n-bit message, decrypt it, and count exact round trips.
trials = 1000
success = 0
for _ in range(trials):
    A, t, s = key_gen()

    m = [sc.choice((0, 1)) for _ in range(n)]

    c = encrypt(A, t, m)
    m_n = decrypt(c, s)

    if m_n == m:
        success += 1

print(f'Success rate: {success * 100 / trials} %')


Success rate: 100.0 %
