In [0]:
!pip install z3-solver

In [0]:
import binascii

def str_to_int(str):
  """Encode a string as the big-endian integer value of its UTF-8 bytes.

  Args:
    str: Text to encode. (The name shadows the builtin; it is kept so that
      existing keyword callers continue to work.)

  Returns:
    A non-negative int. The empty string maps to 0. Leading NUL characters
    contribute nothing to the value, so they cannot be recovered from it.
  """
  # int.from_bytes accepts empty input (giving 0), whereas parsing an empty
  # hex string with int(..., 16) raises ValueError.
  return int.from_bytes(str.encode("utf-8"), "big")

def int_to_str(i):
  """Decode a non-negative integer into the UTF-8 string it represents.

  The integer is interpreted as the big-endian value of the string's UTF-8
  bytes.

  Args:
    i: Non-negative int to decode.

  Returns:
    The decoded text; 0 decodes to the empty string.

  Raises:
    OverflowError: if i is negative.
    UnicodeDecodeError: if the bytes are not valid UTF-8.
  """
  # Size the buffer from the bit length instead of slicing hex(i): hex() drops
  # the leading zero nibble, so a first byte below 0x10 (e.g. "\n" or "\t")
  # produced an odd-length hex string that unhexlify rejected.
  n_bytes = (i.bit_length() + 7) // 8
  return i.to_bytes(n_bytes, "big").decode("utf-8")

In [0]:
import random

def generate_polynomial_fn(a_vars):
  """Return a one-argument function that evaluates a polynomial.

  a_vars[k] is used as the coefficient of x**k, so a_vars[0] is the constant
  term. The number of terms is fixed when this function is called; the
  coefficients themselves are read from a_vars each time the returned
  function is evaluated.
  """
  exponents = range(len(a_vars))

  def evaluate(x):
    # Terms are accumulated from the lowest degree upwards, starting at 0.
    return sum(a_vars[k] * x**k for k in exponents)

  return evaluate

def shamir_split(m, n, secret):
  """Split an integer secret into n shares of a degree m - 1 polynomial.

  The polynomial has `secret` as its constant term and m - 1 randomly drawn
  higher coefficients; it is evaluated at x = 1..n.

  Args:
    m: Number of polynomial coefficients, i.e. how many shares are needed to
      determine the polynomial again.
    n: Number of shares to produce.
    secret: Integer to protect.

  Returns:
    A list of n (x, y) tuples for x = 1..n.
  """
  # NOTE(review): arithmetic is over the plain integers, not a prime field, so
  # each share reveals secret mod x. Fine for a demo, not for real secrets.

  # Draw the coefficients with exact integer arithmetic from the OS entropy
  # source. Scaling the secret by a random float kept only 53 significant bits
  # and raised OverflowError once the secret no longer fit in a float.
  rng = random.SystemRandom()
  sign = -1 if secret < 0 else 1
  bound = abs(secret) + 1  # coefficients lie between 0 and secret, inclusive
  rnds = [sign * rng.randrange(bound) for _ in range(m - 1)]
  fn = generate_polynomial_fn([secret] + rnds)
  return [(x, fn(x)) for x in range(1, n + 1)]

In [0]:
from z3 import *

def shamir_resolve(splits):
  """Recover the secret string from a collection of (x, y) shares.

  One integer unknown is created per share and z3 is asked for coefficients
  of a polynomial passing through every share; the constant term is then
  decoded as text.

  Args:
    splits: Non-empty sequence of (x, y) integer pairs. Supplying fewer
      shares than were required when splitting leaves the system
      under-determined, in which case the value found is not the secret.

  Returns:
    The decoded secret string.

  Raises:
    ValueError: if `splits` is empty.
    Z3Exception: if the solver cannot find integer coefficients that satisfy
      every share.
  """
  if not splits:
    raise ValueError("at least one share is required to reconstruct the secret")

  # One unknown coefficient per share: k shares pin down a polynomial of
  # degree k - 1.
  coeffs = [Int(f'a{i}') for i in range(len(splits))]
  # The symbolic polynomial is identical for every share, so build it once.
  poly = generate_polynomial_fn(coeffs)

  solver = Solver()
  for x, y in splits:
    solver.add(y == poly(x))

  # Inspect the verdict explicitly; asking for a model after a failed check
  # only produces an opaque "model is not available" error.
  outcome = solver.check()
  if outcome != sat:
    raise Z3Exception(
        f"cannot reconstruct secret from the given shares (solver returned {outcome})")

  sec_int = solver.model()[coeffs[0]].as_long()
  return int_to_str(sec_int)

In [0]:
# Round-trip sanity check: encoding the secret as an integer and decoding it
# again must give back the original text.
secret = "this is so secret"
sec_int = str_to_int(secret)
assert int_to_str(sec_int) == secret

In [6]:
# Scheme parameters: n shares are issued from a polynomial with m coefficients.
m, n = 8, 12
splits = shamir_split(m, n, sec_int)
assert len(splits) == n

print(splits)

[(1, 180650296923632994258640862588913396639092), (2, 2708628690954982489456534789611529708922228), (3, 23215261202103551166156938619912820062578036), (4, 123011153429406964967418404662929110277973364), (5, 477704806284841568037772296231917209747875188), (6, 1497339801462192327152940840182444347026990452), (7, 4013482076460702038787620174022863719633151348), (8, 9548649197441519226288757266552383176218666356), (9, 20683473538194964114621749590982334549270357364), (10, 41536991273496632070156996429496251156496803188), (11, 78375449095130353892964230688204360989061309812), (12, 140365020558856032349082060101448103104819127668)]


In [7]:
# Rebuild the secret from a subset of the shares: the slice picks 8 of them
# (list indices 3..10), matching m = 8 used when splitting.
shamir_resolve(splits[3:11])

'this is so secret'