In [1]:
from collections import Counter, defaultdict
import datetime
import math
import pandas as pd
from pprint import PrettyPrinter
import time
from tqdm import tqdm
from z3 import *

# Shorthand used to pretty-print nested dicts / Counters in cell output.
pp = PrettyPrinter(indent=1).pprint

In [2]:
## Example inputs showing the shape of every variable the solver consumes.
## A later cell overwrites all of them with values loaded from the attendee CSV.

# persons: attendee ids. Plain integers keep the example simple; a richer
# Person class could stand in for them.
persons = range(50)

# tables: one capacity per table. For instance [3, 4, 5] describes three
# tables that seat 3, 4 and 5 people respectively.
tables = [10] * 5

# groups: group id -> ids of its members. groups[0] == [1, 2] would mean
# persons 1 and 2 both belong to group 0.
groups = defaultdict(list, {
    0: [0, 1, 2, 3],
    1: [6, 7, 8, 9],
})

# friends: person id -> ids of that person's friends. friends[0] == [1, 2, 3]
# would mean person 0 has three friends: persons 1, 2 and 3.
friends = defaultdict(list, {
    0: [1],
    1: [2, 3, 4],
})

# enemies: pairs of people who must NOT be seated at the same table.
enemies = []

# couples: pairs of people who MUST be seated at the same table.
couples = []

In [63]:
# Raw attendee sheet. Later cells read its attendees_id, table_sizes,
# enemy_a/enemy_b, must_be_together_a/must_be_together_b, group_id and
# group_member_id columns.
df = pd.read_csv(filepath_or_buffer='../example_data/332_attendees.csv')

In [64]:
# Turn every CSV column into the list of its non-missing values. The columns
# hold unrelated lists of different lengths (attendees, table sizes, pairs,
# group memberships), presumably blank-padded in the sheet, so the NaN padding
# is dropped per column. dropna() also copes with non-float columns, where
# math.isnan would raise.
data_dict = {col: list(df[col].dropna()) for col in df.columns}

persons = data_dict['attendees_id']
tables = [int(size) for size in data_dict['table_sizes']]

# Attendee id -> position in `persons`, built once so each lookup is O(1)
# instead of a linear list.index scan. Iterating in reverse lets the first
# occurrence of a duplicated id win, matching list.index.
_person_index_by_id = {pid: i for i, pid in reversed(list(enumerate(persons)))}


def get_person_index(person_id):
    """Return the 0-based index of attendee `person_id` in `persons`.

    Raises ValueError (as list.index does) if the id is not an attendee.
    """
    try:
        return _person_index_by_id[person_id]
    except KeyError:
        raise ValueError('{!r} is not in the attendee list'.format(person_id))


def _parallel_columns(col_a, col_b):
    """Return the value lists of two columns that must pair up row by row.

    Raises ValueError if their non-empty value counts differ, because zip()
    would otherwise silently drop or misalign pairs.
    """
    values_a, values_b = data_dict[col_a], data_dict[col_b]
    if len(values_a) != len(values_b):
        raise ValueError('columns {!r} and {!r} have {} vs {} values'.format(
            col_a, col_b, len(values_a), len(values_b)))
    return values_a, values_b


def _index_pairs(col_a, col_b):
    """Return [(index_a, index_b), ...] for two parallel attendee-id columns."""
    ids_a, ids_b = _parallel_columns(col_a, col_b)
    return [(get_person_index(a), get_person_index(b)) for a, b in zip(ids_a, ids_b)]


enemies = _index_pairs('enemy_a', 'enemy_b')
couples = _index_pairs('must_be_together_a', 'must_be_together_b')

group_id_values, group_member_values = _parallel_columns('group_id', 'group_member_id')
group_member_indices = [get_person_index(m) for m in group_member_values]

# group id -> list of member indices (into `persons`).
groups = defaultdict(list)
for group_id, member_index in zip(group_id_values, group_member_indices):
    groups[int(group_id)].append(member_index)

# Anyone who shares at least one group with a person counts as their friend.
friends_set = defaultdict(set)
for g_members in groups.values():
    for member in g_members:
        friends_set[member].update(g_members)

# person index -> list of friend indices, excluding the person themself.
friends = defaultdict(list)
for p, p_friends in friends_set.items():
    friends[p] = list(p_friends - {p})

In [65]:
# 0-based positions into `persons` and `tables`. The z3 variables and table
# values are keyed by these indices, not by the raw CSV ids.
person_ids, table_ids = range(len(persons)), range(len(tables))

In [66]:
def get_person_int(p_id):
    """Return the z3 Int variable holding the table index of person `p_id`.

    The variable is named 'p<p_id>'. Calls with the same id therefore build
    the same named constant, so every constraint refers to one variable.
    """
    var_name = 'p{}'.format(p_id)
    return Int(var_name)

def get_persons_int(p_ids=None):
    """Return a list of z3 table variables, one per person index in `p_ids`.

    p_ids: iterable of person indices. When omitted, every index in the
        module-level `person_ids` is used. That name is looked up at call time
        rather than frozen at definition time, so re-running the data cells
        is picked up.

    Returns a concrete list, as Python 2's map() did, rather than a one-shot
    Python 3 map iterator. The result can be iterated repeatedly and passed
    straight to z3's Sum.
    """
    if p_ids is None:
        p_ids = person_ids
    return [get_person_int(p_id) for p_id in p_ids]

def get_couple_constraint(c1, c2):
    """Return a z3 constraint requiring persons c1 and c2 to share a table."""
    return get_person_int(c1) == get_person_int(c2)

def get_enemy_constraint(e1, e2):
    """Return a z3 constraint forbidding persons e1 and e2 from sharing a table."""
    return get_person_int(e1) != get_person_int(e2)

In [67]:
# Largest share of a table, in percent, that members of one group may fill.
MAX_GROUP_SHARE_PCT = 70


def _count_at_table(person_ints, table):
    """Return a z3 term counting how many of `person_ints` sit at `table`.

    Builds a concrete list so it works whether `person_ints` is a list or a
    one-shot iterator.
    """
    return Sum([If(p_int == table, 1, 0) for p_int in person_ints])


hard_cons = []

# Each person's variable is the index of their table, so it must lie in
# [0, len(tables)). A single Int per person already means "exactly one table".
for p_i in person_ids:
    p = get_person_int(p_i)
    hard_cons.append(And(p >= 0, p < len(tables)))

# Each table seats at most its capacity.
all_person_ints = list(get_persons_int(person_ids))
for t_i, t_cap in enumerate(tables):
    hard_cons.append(_count_at_table(all_person_ints, t_i) <= t_cap)

# Couples must be seated together.
for c1, c2 in couples:
    hard_cons.append(get_couple_constraint(c1, c2))

# Enemies must be seated separately.
for e1, e2 in enemies:
    hard_cons.append(get_enemy_constraint(e1, e2))

# Friendship: someone with several friends must sit with at least two of
# them; someone with exactly one friend must sit with that friend.
for p, p_friends in tqdm(friends.items()):
    if len(p_friends) > 1:
        p_int = get_person_int(p)
        friend_ints = list(get_persons_int(p_friends))
        for t in table_ids:
            hard_cons.append(Implies(p_int == t, _count_at_table(friend_ints, t) >= 2))
    elif p_friends:
        hard_cons.append(get_couple_constraint(p, p_friends[0]))

# No single group may fill more than MAX_GROUP_SHARE_PCT% of any table.
# Integer arithmetic avoids float truncation: 0.7 * 90 evaluates to
# 62.99999999999999, so int(90 * 0.7) would give 62 instead of 63.
for t, t_cap in tqdm(list(enumerate(tables))):
    group_cap = t_cap * MAX_GROUP_SHARE_PCT // 100
    # Groups no larger than the cap can never exceed it, so skip them.
    affected_groups = {k: v for k, v in groups.items() if len(v) > group_cap}
    for g, g_members in affected_groups.items():
        g_members_on_t = _count_at_table(get_persons_int(g_members), t)
        hard_cons.append(g_members_on_t <= group_cap)

print('constraints enumerated')

100%|██████████| 222/222 [00:46<00:00,  4.79it/s]
100%|██████████| 34/34 [00:02<00:00, 12.49it/s]

constraints enumerated





In [68]:
# %time get_model(And(hard_cons))
print datetime.datetime.now()
t_start = time.time()
s = Solver()
s.add(And(hard_cons))
%time check_sat = s.check()
if check_sat == sat:
    print 'sat'
    model = s.model()
else:
    print "unsat"
t_end = time.time()

total_time = t_end - t_start
print total_time
print datetime.datetime.now()

2018-05-25 13:57:45.230351
CPU times: user 30min 14s, sys: 6.61 s, total: 30min 21s
Wall time: 30min 35s
sat
1835.97564197
2018-05-25 14:28:21.206950


In [69]:
# Summarise the seating: for each table, how many seated people belong to each
# group (None = person in no group).
assert check_sat == sat, 'no model to inspect: solver returned {}'.format(check_sat)

# Reverse index: person index -> ids of every group they belong to.
rev_g = defaultdict(list)
for g, g_mems in groups.items():
    for mem in g_mems:
        rev_g[mem].append(g)

# People in no group get a single None entry, so they still appear in the counts.
group_mems = {mem for g_mems in groups.values() for mem in g_mems}
no_group_mems = set(person_ids) - group_mems
for p in no_group_mems:
    rev_g[p].append(None)

# Table index assigned to each person. model_completion makes eval return a
# concrete value even for a variable the solver left unconstrained.
m_vals = [model.eval(get_person_int(x), model_completion=True).as_long()
          for x in person_ids]

# table index -> {person index: [group ids of that person]}
table_vals_raw = defaultdict(lambda: defaultdict(list))
for i, val in enumerate(m_vals):
    table_vals_raw[val][i].extend(rev_g[i])

table_vals = {k: dict(v) for k, v in table_vals_raw.items()}

# Flatten each table's per-person group lists with a comprehension instead of
# reduce(), which is not a builtin on Python 3 and concatenates lists in O(n^2).
table_group_stats = {
    k: Counter(grp for person_groups in v.values() for grp in person_groups)
    for k, v in table_vals.items()
}

pp(table_group_stats)

{0: Counter({33: 4, 35: 3, 37: 3, 41: 1, None: 1}),
 1: Counter({37: 6, 19: 3, 35: 2, None: 1}),
 2: Counter({24: 4, None: 3, 6: 3, 27: 2, 25: 1}),
 3: Counter({None: 4, 40: 3, 20: 3}),
 4: Counter({38: 6, 23: 3, 5: 1, None: 1, 19: 1, 20: 1}),
 5: Counter({None: 5, 20: 3}),
 6: Counter({None: 9}),
 7: Counter({None: 6, 40: 4}),
 8: Counter({None: 6, 14: 4, 13: 1}),
 9: Counter({None: 7, 8: 3}),
 10: Counter({31: 4, 17: 3, 12: 3, 9: 2, 23: 2}),
 11: Counter({None: 5}),
 12: Counter({None: 7, 42: 3}),
 13: Counter({41: 5, 25: 4, 32: 2, 10: 1, 12: 1}),
 14: Counter({None: 10}),
 15: Counter({None: 4, 45: 4, 42: 3}),
 16: Counter({None: 7, 39: 3}),
 17: Counter({5: 3, 44: 3, 27: 3, 4: 1, 7: 1, 17: 1, None: 1}),
 18: Counter({37: 4, None: 3, 23: 3, 33: 1}),
 19: Counter({None: 10}),
 20: Counter({None: 10}),
 21: Counter({25: 4, None: 4, 28: 3, 34: 1, 31: 1}),
 22: Counter({None: 4, 18: 3, 45: 3}),
 23: Counter({1: 5, 12: 4, 23: 2, 2: 1, 6: 1, 16: 1, None: 1}),
 24: Counter({38: 4, 16: 3, 1