In [1]:
import secrets
from random import randint, choice
n = 16  # Size parameter for the Verifier's random challenge `e`, drawn from a range bounded by 2**n.

# Generate the cyclic subgroup $\langle g \rangle$ of $\mathbb{Z}_p^*$

In [2]:
def is_prime(n: int) -> bool:
    """Return True if ``n`` is a prime number, using trial division.

    Only divisors up to ``sqrt(n)`` need to be tried: a composite ``n`` always
    has a factor no larger than its square root.

    Args:
        n: Integer to test. Values below 2 (0, 1 and negatives) are not prime.

    Returns:
        True when ``n`` is prime, False otherwise.
    """
    if n < 2:
        return False
    divisor = 2
    while divisor * divisor <= n:
        if n % divisor == 0:
            return False
        divisor += 1
    return True

assert is_prime(17)
assert is_prime(47)
assert not is_prime(16)
# Edge cases: the smallest composite and the non-primes below 2.
assert not is_prime(4)
assert not is_prime(1)
assert not is_prime(0)

In [3]:
class Group:
    """Cyclic subgroup ``<g>`` of the multiplicative group Z_p^*, for a prime ``p``.

    Attributes:
        p: Prime modulus.
        g: Generator, with ``1 < g < p``.
        q: Order of the subgroup (number of distinct powers of ``g`` mod ``p``).
        group: Tuple of the elements in generation order
            ``(g, g**2 % p, ..., 1)``; the final ``1`` is ``g**q % p``.
    """

    def __generate(self):
        """Enumerate the powers of ``g`` mod ``p`` and set ``q``, ``group`` and the member set."""
        p, g = self.p, self.g
        elements = []
        cur = g
        while cur != 1:
            # Powers of a unit mod p return to 1 within p - 1 steps, so reaching
            # p elements means g is not invertible mod p (possible only if p is
            # not actually prime). Stop instead of looping forever.
            assert len(elements) < p, f"g = {g} does not generate a cyclic subgroup mod p = {p}"
            elements.append(cur)
            cur = (cur * g) % p
        elements.append(1)
        self.q = len(elements)
        self.group = tuple(elements)
        # Hashed copy of the same elements so isMember() avoids a linear tuple scan.
        self._members = frozenset(elements)

    def __init__(self, p, g):
        """Build the subgroup of Z_p^* generated by ``g``.

        Args:
            p: Modulus; must be prime.
            g: Generator; must satisfy ``1 < g < p``. It need not be prime itself:
                every such ``g`` is invertible mod the prime ``p`` and therefore
                generates a cyclic subgroup.

        Raises:
            AssertionError: If ``p`` is not prime or ``g`` is outside ``(1, p)``.
        """
        assert is_prime(p), f"Value for p {p} is not a prime"
        # g = 0 would never reach 1 and g = 1 only yields the trivial group {1}.
        assert 1 < g < p, f"Value for g {g} must satisfy 1 < g < p = {p}"
        self.p = p
        self.g = g
        self.__generate()

    def __str__(self):
        """Human-readable summary: parameters p, g, q and the element tuple."""
        return f"(p: {self.p}, g: {self.g}, q: {self.q}): G = {self.group}"

    def getValues(self):
        """Return ``(p, g, q, group)`` in one call."""
        return self.p, self.g, self.q, self.group

    def getP(self):
        """Return the prime modulus ``p``."""
        return self.p

    def getG(self):
        """Return the generator ``g``."""
        return self.g

    def getQ(self):
        """Return the subgroup order ``q``."""
        return self.q

    def getGroup(self):
        """Return the subgroup elements as a tuple, in generation order."""
        return self.group

    def isMember(self, x: int):
        """Return True if ``x`` is one of the subgroup's elements."""
        return x in self._members

group = Group(47, 17)
print(group)

# Expected powers 17**1, 17**2, ..., 17**23 (mod 47); the last one wraps to 1.
assert group.getGroup() == (
    17, 7, 25, 2, 34, 14, 3, 4, 21, 28, 6, 8,
    42, 9, 12, 16, 37, 18, 24, 32, 27, 36, 1,
)
assert all(group.isMember(x) for x in (17, 1, 42))
assert not group.isMember(15)

(p: 47, g: 17, q: 23): G = (17, 7, 25, 2, 34, 14, 3, 4, 21, 28, 6, 8, 42, 9, 12, 16, 37, 18, 24, 32, 27, 36, 1)


# Schnorr's protocol

In [4]:
class Prover:
    """Prover side of Schnorr's identification protocol.

    Holds a secret ``x`` and the public value ``h = g**x mod p``. One round is
    ``getCommitment()`` -> (verifier picks a challenge ``e``) -> ``getProof(e)``.
    """

    def __generateSignature(self):
        """Compute the public value ``h = g**x mod p``."""
        p, g, _, _ = self.group.getValues()
        self.h = pow(g, self.x, p)

    def __init__(self, group: Group, x: int):
        """Create a prover over ``group`` holding the secret ``x``.

        NOTE(review): ``x`` is required to be an element of ``group``, although it
        is only ever used as an exponent (``pow(g, x, p)`` and ``x * e`` reduced
        mod ``q``). Kept for compatibility; confirm whether a ``0 <= x < q``
        range check was intended instead.

        Raises:
            AssertionError: If ``x`` is not a member of ``group``.
        """
        assert group.isMember(x), f"The secret x = {x} has to be in group G"
        self.group = group
        self.x = x
        # Nonce of the pending round; None when no commitment is outstanding.
        self.r = None
        self.__generateSignature()

    def getSignature(self):
        """Return the public value ``h = g**x mod p``."""
        return self.h

    def getCommitment(self):
        """Start a round: draw a fresh nonce ``r`` in ``[0, q)`` and return ``a = g**r mod p``.

        The nonce comes from ``secrets`` rather than ``random`` (whose generator
        is not meant for security use). If ``r`` can be predicted, the response
        ``f = r + x*e (mod q)`` of a single round can be solved for ``x``.
        """
        p, g, q, _ = self.group.getValues()
        self.r = secrets.randbelow(q)
        return pow(g, self.r, p)

    def getProof(self, e: int):
        """Answer challenge ``e`` with ``f = (r + x*e) mod q``, consuming the nonce.

        The nonce is cleared so it can never answer a second challenge. Two
        responses for the same ``r`` can be subtracted to eliminate ``r`` and
        solve for ``x``.

        Raises:
            AssertionError: If no commitment is pending.
        """
        assert self.r is not None, "getCommitment() must be called before getProof()"
        r, self.r = self.r, None
        q = self.group.getQ()
        return (r + self.x * e) % q

prover = Prover(group, 21)

# Public value h = g**x mod p (g = 17, x = 21, p = 47); as a power of g it lies in G.
assert prover.getSignature() == pow(17, 21, 47)
assert group.isMember(prover.getSignature())

# A fresh commitment a = g**r mod p is likewise a power of g, hence in G.
assert group.isMember(prover.getCommitment())

In [5]:
class Verifier:
    """Verifier side of Schnorr's identification protocol.

    Holds the prover's public value ``h``. Each round it records the commitment
    ``a``, issues a random challenge ``e`` and accepts the response ``f`` iff
    ``g**f == a * h**e (mod p)``.
    """

    def __init__(self, group: Group, h: int):
        """Create a verifier over ``group`` for the public value ``h``.

        Raises:
            AssertionError: If ``h`` is not a member of ``group``.
        """
        assert group.isMember(h), f"h = {h} is not a signature of group G = {group}"
        self.group = group
        self.h = h
        # Per-round state set by getChallenge(); None until the first challenge.
        self.a = None
        self.e = None

    def getChallenge(self, a: int):
        """Record the commitment ``a`` and return a fresh challenge ``e``.

        ``e`` is uniform in ``[0, 2**n)``, i.e. an ``n``-bit value. It is drawn
        with ``secrets`` so the prover cannot predict it. A prover that knows
        ``e`` before committing can pass without knowing the secret.
        """
        self.a = a
        self.e = secrets.randbelow(2**n)
        return self.e

    def verifyProof(self, f: int):
        """Return True iff ``g**f == a * h**e (mod p)`` for the current round.

        Reducing ``e`` mod ``q`` is valid because ``h`` is an element of G, and
        every element of G satisfies ``h**q == 1 (mod p)``. It only keeps the
        exponent small; the result is the same as with the unreduced ``e``.

        Raises:
            AssertionError: If getChallenge() has not been called yet.
        """
        assert self.a is not None, "getChallenge() must be called before verifyProof()"
        assert self.e is not None, "getChallenge() must be called before verifyProof()"
        p, g, q, _ = self.group.getValues()
        return pow(g, f, p) == (self.a * pow(self.h, self.e % q, p)) % p
    
# The verifier is given only the prover's public value h, never the secret x.
verifier = Verifier(group, h=prover.getSignature())
# Any challenge must fall in the closed range [0, 2**n].
assert verifier.getChallenge(17) in range(2**n + 1)

In [6]:
def transcript(prover: Prover, verifier: Verifier, rounds: int = 100) -> bool:
    """Run ``rounds`` rounds of Schnorr's protocol and report whether all passed.

    Only ``a``, ``e`` and ``f`` are exchanged between the two parties, so they
    form the publicly visible transcript. A prover without the right secret can
    still pass a single round by chance, so several rounds are run and every one
    of them must succeed.

    Args:
        prover: Party claiming to know the secret behind its public value.
        verifier: Party checking that claim.
        rounds: Number of rounds to run (default 100).

    Returns:
        True if every round verified, False as soon as one round fails.
    """
    for _ in range(rounds):
        a = prover.getCommitment()
        e = verifier.getChallenge(a)
        f = prover.getProof(e)
        if not verifier.verifyProof(f):
            return False
    return True

# Completeness: the honest prover must convince the verifier every time.
assert all(transcript(prover, verifier) for _ in range(1000))

# Soundness: an impostor holding x = 9 instead of 21 may pass a round only by
# chance, so it must never pass a full multi-round transcript.
malicious = Prover(group, 9)
assert not any(transcript(malicious, verifier) for _ in range(1000))