In [1]:
!pip install z3-solver



In [0]:
# Definition of simple functions to encode a string into an integer and back

import binascii

def str_to_int(str):
  """Encode a text string as a single non-negative integer.

  The text is UTF-8 encoded and the resulting bytes are read as one
  big-endian unsigned integer (e.g. "hi" -> 0x6869).

  Args:
    str: the text to encode. The parameter name shadows the builtin
      ``str``; it is kept so existing keyword callers keep working.

  Returns:
    The integer value of the UTF-8 bytes; 0 for the empty string.

  Note:
    Leading NUL characters are lost by this encoding, because leading
    zero bytes do not change the integer value.
  """
  byts = str.encode("utf-8")
  # int.from_bytes maps b"" to 0, whereas int(binascii.hexlify(b""), 16)
  # would raise ValueError on the empty hex string.
  return int.from_bytes(byts, "big")

def int_to_str(i):
  """Decode an integer produced by ``str_to_int`` back into text.

  The integer is written out as the minimal number of big-endian bytes,
  which are then decoded as UTF-8.

  Args:
    i: non-negative integer whose big-endian bytes are UTF-8 text.

  Returns:
    The decoded string; "" for 0.

  Raises:
    ValueError: if ``i`` is negative.
    UnicodeDecodeError: if the bytes are not valid UTF-8.
  """
  if i < 0:
    raise ValueError("cannot decode a negative integer")
  # Minimal byte length. Going through hex(i)[2:] + unhexlify instead would
  # fail whenever the leading byte is < 0x10 (odd-length hex string), e.g.
  # for text starting with "\n", and also for i == 0.
  n_bytes = (i.bit_length() + 7) // 8
  return i.to_bytes(n_bytes, "big").decode("utf-8")

In [0]:
import random

# Build a polynomial function from a list of coefficients.
# (The number of coefficients also determines the polynomial's degree: len - 1.)
# https://en.wikipedia.org/wiki/Shamir%27s_Secret_Sharing#Usage
def generate_polynomial_fn(a_vars):
  """Return f(x) = a_vars[0] + a_vars[1]*x + ... + a_vars[k-1]*x**(k-1).

  ``a_vars`` may hold plain ints or symbolic values (e.g. Z3 ``Int``
  variables); the returned function only applies ``*``, ``**`` and ``+``.
  The number of terms k is fixed to ``len(a_vars)`` when this is called,
  while the coefficients themselves are read from ``a_vars`` each time the
  returned function is evaluated.
  """
  n_terms = len(a_vars)

  def evaluate(x):
    # sum() starts from 0 and adds the terms in increasing power order.
    return sum(a_vars[power] * x**power for power in range(n_terms))

  return evaluate

In [0]:
def shamir_split(m, n, secret):
  """Split an integer secret into n shares, any m of which recover it.

  Builds a random polynomial with m coefficients (degree m - 1) whose
  constant term is the secret, and returns its evaluations at x = 1..n.

  Args:
    m: threshold, i.e. number of polynomial coefficients (must be >= 1).
    n: number of shares to produce (must be >= m, otherwise the secret
      could never be reconstructed from the shares).
    secret: the integer to share.

  Returns:
    A list of n ``(x, f(x))`` tuples for x = 1..n.

  Raises:
    ValueError: if m < 1 or n < m.

  NOTE(review): this uses plain integer arithmetic rather than a finite
  field as in textbook Shamir, so shares leak information about the
  secret; suitable for demonstration only.
  """
  if m < 1:
    # With m < 1 the polynomial is just [secret]: every share IS the secret.
    raise ValueError(f"threshold m must be >= 1, got {m}")
  if n < m:
    raise ValueError(f"need n >= m shares to reconstruct, got m={m}, n={n}")

  # Coefficients are drawn "not too far away" from the secret: between 20%
  # and 80% of it. Integer bounds avoid converting the secret to float,
  # which overflows for secrets above ~1.8e308 (strings over ~128 bytes).
  # SystemRandom draws from the OS CSPRNG; the default `random` generator
  # is not meant for security-sensitive values.
  rng = random.SystemRandom()
  low, high = sorted((secret // 5, secret * 4 // 5))  # sorted: negative secrets
  rnds = [rng.randint(low, high) for _ in range(m - 1)]
  f = generate_polynomial_fn([secret] + rnds)
  # WARNING: x = 0 is deliberately skipped, since f(0) == secret.
  return [(x, f(x)) for x in range(1, n + 1)]

In [0]:
from z3 import *

# https://en.wikipedia.org/wiki/Shamir%27s_Secret_Sharing#Solution
def shamir_resolve(splits):
  """Reconstruct the secret string from Shamir shares using Z3.

  Treats the polynomial coefficients a0..a(k-1) as unknown integers, where
  k is the number of distinct x values in ``splits``. It adds one equation
  ``y == f(x)`` per share and lets Z3 solve the system. The constant term
  a0 is the encoded secret.

  Args:
    splits: iterable of ``(x, y)`` shares as returned by ``shamir_split``.
      At least m distinct shares are required. With fewer, the system is
      solved for a lower-degree polynomial, and Z3 either finds no integer
      solution (ValueError) or yields a wrong secret.

  Returns:
    The secret decoded with ``int_to_str``.

  Raises:
    ValueError: if ``splits`` is empty or Z3 cannot solve the system
      (e.g. inconsistent shares).
  """
  splits = list(splits)
  if not splits:
    raise ValueError("at least one share is required")

  # One unknown coefficient per distinct x. Duplicated shares add no
  # information, and counting them would leave the system under-determined.
  n_coeffs = len({x for x, _ in splits})

  # Z3 way of declaring unknown variables
  a = [Int(f'a{i}') for i in range(n_coeffs)]
  p = generate_polynomial_fn(a)  # loop-invariant: build the polynomial once

  # Create a simple system of equations from the splits
  solver = Solver()
  for x, y in splits:
    solver.add(y == p(x))

  # check() returns sat / unsat / unknown; a model only exists after sat.
  result = solver.check()
  if result != sat:
    raise ValueError(f"could not solve the system of shares (Z3 returned {result})")
  model = solver.model()
  sec_int = model.eval(a[0], model_completion=True).as_long()

  return int_to_str(sec_int)

In [0]:
secret = "this is so secret"

# Encode the secret text as one integer; that integer is what gets shared.
sec_int = str_to_int(secret)

# Sanity check: decoding the integer must give back the original text.
assert int_to_str(sec_int) == secret

In [7]:
m = 7   # threshold: number of shares needed to rebuild the secret
n = 12  # total number of shares generated
splits = shamir_split(m, n, sec_int)

# Exactly one share per x in 1..n.
assert len(splits) == n
print(splits)

[(1, 128578265812178243783178166342781756401012), (2, 1929330870138865244530456899917352686216564), (3, 15336725934380767844321927151285452194080116), (4, 71837137908574446326760744108841065260410228), (5, 244273538839691018525711883516500403855451508), (6, 673035469417389632886173932031131345992443252), (7, 1598673901339026316157179490230839621761262964), (8, 3400940990993918195716724188274111746342544756), (9, 6644254724466863096528724314209814700002272628), (10, 12129588453860914512732003054938052355066848628), (11, 20952785324939411953861305349821878649878635892), (12, 34569297596087266665700341356949867509731976564)]


In [8]:
# Pick an arbitrary subset of the shares: here the ones for x = 4..10.
split_subset = splits[3:10]

# Reconstruction needs at least m shares.
assert len(split_subset) >= m

# Rebuild the secret from the shares alone (each share is an (x, y) pair).
shamir_resolve(split_subset)

'this is so secret'