In [1]:
!pip install z3-solver



In [2]:
import secrets

import z3
# Show which z3 build the kernel picked up; left as the last expression so
# Jupyter renders it via rich display instead of print().
z3.get_version_string()

4.8.0


In [0]:
import binascii

def str_to_int(str):
  """Encode text as a single non-negative integer.

  The text is UTF-8 encoded and the resulting bytes are read as one
  big-endian unsigned integer (the first byte is the most significant).

  Args:
    str: the text to encode. The parameter name shadows the builtin
      ``str`` inside this function; it is kept so existing callers that
      pass it by keyword keep working.

  Returns:
    The integer value of the UTF-8 bytes; ``0`` for the empty string.

  Note:
    Leading NUL characters do not change the value, so they cannot be
    recovered by ``int_to_str``.
  """
  # int.from_bytes skips the hex round-trip and, unlike int(b"", 16),
  # accepts empty input instead of raising ValueError.
  return int.from_bytes(str.encode("utf-8"), "big")

def int_to_str(i):
  """Decode an integer produced by ``str_to_int`` back into text.

  The integer is written out as the minimal number of big-endian bytes
  and those bytes are decoded as UTF-8.

  Args:
    i: a non-negative integer.

  Returns:
    The decoded text; ``""`` for ``0``.

  Raises:
    ValueError: if ``i`` is negative.
    UnicodeDecodeError: if the bytes are not valid UTF-8, e.g. when ``i``
      is not the encoding of any string.
  """
  if i < 0:
    raise ValueError(f"cannot decode negative integer {i}")
  # Round the bit length up to whole bytes. The previous hex()-based
  # version produced an odd-length hex string whenever the top byte was
  # below 0x10 (e.g. text starting with a newline), and unhexlify rejected it.
  n_bytes = (i.bit_length() + 7) // 8
  return i.to_bytes(n_bytes, "big").decode("utf-8")

In [0]:
import random

def shamir_split(m, n, secret):
  """Split an integer secret into ``n`` shares, any ``m`` of which recover it.

  Builds a polynomial of degree ``m - 1`` whose constant term is ``secret``
  and whose other ``m - 1`` coefficients are random integers in
  ``[0, secret)``. It then evaluates that polynomial at ``x = 1 .. n``.

  NOTE: arithmetic is over the plain integers rather than a prime field,
  so shares are not information-theoretically hiding. Treat this as a
  demonstration, not production secret sharing.

  Args:
    m: threshold, i.e. the number of shares needed to reconstruct
      (``1 <= m <= n``).
    n: total number of shares to produce.
    secret: non-negative integer to share (e.g. the output of ``str_to_int``).

  Returns:
    A list of ``(x, y)`` tuples for ``x`` in ``1 .. n``.

  Raises:
    ValueError: if ``m`` is not in ``1 .. n`` or ``secret`` is negative.
  """
  if not 1 <= m <= n:
    raise ValueError(f"threshold m={m} must satisfy 1 <= m <= n={n}")
  if secret < 0:
    raise ValueError("secret must be a non-negative integer")
  # Draw coefficients from the CSPRNG in `secrets`, because `random` is not
  # meant for security use. Also, `random.uniform(0, 1) * secret` only ever
  # carried the ~53 bits of a float, however large the secret was.
  # max(secret, 1) keeps randbelow valid when secret == 0; the coefficients
  # are then all 0, as before.
  coeffs = [secrets.randbelow(max(secret, 1)) for _ in range(m - 1)]
  fn = generate_polynomial_fn([secret] + coeffs)
  return [(x, fn(x)) for x in range(1, n + 1)]
  
def generate_polynomial_fn(a_vars):
  """Return a one-argument callable computing sum(a_vars[d] * x**d).

  ``a_vars[d]`` is the coefficient of ``x**d``. Coefficients may be plain
  ints or z3 integer expressions; ``shamir_resolve`` passes z3 ``Int``s.
  The number of terms is fixed when this function is called. Each
  coefficient is looked up in ``a_vars`` at evaluation time.
  """
  term_count = len(a_vars)

  def evaluate(x):
    return sum(a_vars[power] * x**power for power in range(term_count))

  return evaluate

In [0]:
from z3 import *

def shamir_resolve(splits):
  """Recover the secret text from a list of Shamir shares.

  Each share is treated as a point on a polynomial with ``len(splits)``
  unknown integer coefficients. z3 is asked for coefficients consistent
  with every point, and the constant term is decoded with ``int_to_str``.

  The result is the original secret only when at least the threshold
  ``m`` shares from ``shamir_split`` are supplied. With fewer, z3 either
  finds no integer solution or a different polynomial.

  Args:
    splits: non-empty sequence of ``(x, y)`` integer pairs with distinct ``x``.

  Returns:
    The decoded secret string.

  Raises:
    ValueError: if ``splits`` is empty, has duplicate ``x`` values, or no
      integer polynomial fits the shares.
  """
  if not splits:
    raise ValueError("need at least one share to resolve a secret")
  xs = [x for x, _ in splits]
  if len(set(xs)) != len(xs):
    raise ValueError("shares must have distinct x values")

  # One unknown coefficient per share: a0 + a1*x + ... + a{k-1}*x**(k-1).
  num_coeffs = len(splits)
  a = [Int(f'a{i}') for i in range(num_coeffs)]
  p = generate_polynomial_fn(a)  # loop-invariant, so build it once
  solver = Solver()
  for x, y in splits:
    solver.add(y == p(x))

  # Without this check, an unsat/unknown result only shows up later as an
  # opaque exception from solver.model().
  result = solver.check()
  if result != sat:
    raise ValueError(f"could not solve for the polynomial (z3 returned {result})")
  model = solver.model()
  sec_int = model.eval(a[0], model_completion=True).as_long()
  return int_to_str(sec_int)

In [0]:
# Encode the demo secret as an integer and confirm the text <-> int
# encoding round-trips before splitting it.
secret = "this is so secret"
sec_int = str_to_int(secret)
assert secret == int_to_str(sec_int)

In [7]:
# Threshold scheme parameters: m shares are required, n shares are issued.
m, n = 8, 12
splits = shamir_split(m, n, sec_int)
assert n == len(splits)

print(splits)

[(1, 203684099196012614703441772101570644829556), (2, 6563068433346436688981828537646297761604980), (3, 89969854494472429004150444092329312807708020), (4, 621583368174339462084651752340737666827052404), (5, 2843937020943286176605152928161929389250078068), (6, 9939906607838692814173627283622988831373026676), (7, 28761996556149697161692476351143053238417384820), (8, 72384360627142965332819853806599188204729034100), (9, 163639973573174324712053352415365489666676647284), (10, 339805370252531066386958215184276790087808869748), (11, 658595368704349724394062231902508346490829825396), (12, 1205630193685954139103939482254362887990952486260)]


In [8]:
# Any m shares suffice. Take the last m instead of a hard-coded slice so
# this cell stays correct if m or n above is changed.
recovered = shamir_resolve(splits[-m:])
assert recovered == secret
recovered

'this is so secret'