In [1]:
from collections import Counter, defaultdict
import datetime
from functools import partial
import numpy as np
import pandas as pd
from pprint import PrettyPrinter
import time
from tqdm import tqdm
from pysmt.shortcuts import *
import math

pp = PrettyPrinter().pprint

In [2]:
## This cell is an example of the variables
## required in the data.

# persons is a list of ids (for simplicity)
# we can be more sophisticated and use a 
# Person class to represent each person
persons = range(50)


# tables is a list of table sizes
# e.g. if tables = [3, 4, 5], it means
# there are 3 tables of sizes 3, 4 and 5
# respectively.
tables = [10] * 5

# groups is a dictionary that maps a group id
# to a list of all the people in that group
# e.g. if groups[0] == [1, 2], it means
# person 1 and 2 belong to group 0
groups = defaultdict(list)
groups[0] = [0, 1, 2, 3]
groups[1] = [6, 7, 8, 9]

# friends is a dictionary that maps a person_id
# to a list of all that person's friends.
# e.g. if friends[0] == [1, 2, 3], it means 
# that person 0 has three friends: persons 1, 2, 3
friends = defaultdict(list)
friends[0] = [1]
friends[1] = [2, 3, 4]

# enemies is a list of pairs where
# each pair holds two people that 
# shouldn't be seated on the same table
enemies = []

# couples is a list of pairs where
# each pair holds two people that
# must be seated on the same table
couples = []

In [3]:
df = pd.read_csv('example_data/50_attendees.csv')

In [4]:
data_dict = dict()
for col in df.columns:
    data_dict[col] = list(filter(lambda x: not(math.isnan(x)), df[col]))
    
persons = data_dict['attendees_id']
tables = list(map(int, data_dict['table_sizes']))

get_person_index = lambda x: persons.index(x)
enemy_a_indices = map(get_person_index, data_dict['enemy_a'])
enemy_b_indices = map(get_person_index, data_dict['enemy_b'])
couple_a_indices = map(get_person_index, data_dict['must_be_together_a'])
couple_b_indices = map(get_person_index, data_dict['must_be_together_b'])
enemies = list(zip(enemy_a_indices, enemy_b_indices))
couples = list(zip(couple_a_indices, couple_b_indices))
group_member_indices = list(map(get_person_index, data_dict['group_member_id']))

groups = defaultdict(list)
for group_id, group_member_id in zip(data_dict['group_id'], group_member_indices):
    groups[group_id].append(group_member_id)

friends_set = defaultdict(set)
for _, g_members in groups.items():
    for member in g_members:
        friends_set[member] = friends_set[member].union(g_members)

friends = defaultdict(list)
for p, p_friends in friends_set.items():
    friends[p] = list(p_friends.difference(set([p])))

In [5]:
person_ids = range(len(persons))
table_ids = range(len(tables))

In [6]:
def get_person_int(p_id):
    return Symbol('p{}'.format(p_id), INT)

def get_persons_int(p_ids=person_ids):
    return map(get_person_int, p_ids)

def get_couple_constraint(c1, c2):
    c1p = get_person_int(c1)
    c2p = get_person_int(c2)
    return Equals(c1p, c2p)

def get_enemy_constraint(e1, e2):
    e1p = get_person_int(e1)
    e2p = get_person_int(e2)
    return Not(Equals(e1p, e2p))

In [7]:
hard_cons = []

# each person can only be in one table
for p_i, _ in enumerate(persons):
    p = get_person_int(p_i)
    cons = And(p >= 0, p <= table_ids[-1])
    hard_cons.append(cons)
    
# each table has a fixed capacity
for t_i, t_cap in enumerate(tables):
    ps = get_persons_int()
    cons = Plus(map(lambda p: Ite(Equals(p, Int(t_i)), Int(1), Int(0)), ps)) <= t_cap
    hard_cons.append(cons)
    
# couples must be seated together
for c1, c2 in couples:
    c_cons = get_couple_constraint(c1, c2)
    hard_cons.append(c_cons)

# enemies must be seated separately
for e1, e2 in enemies:
    e_cons = get_enemy_constraint(e1, e2)
    hard_cons.append(e_cons)

for p, p_friends in tqdm(friends.items()):
#     continue
    p_int = get_person_int(p)
    if len(p_friends) > 1:
        for t in table_ids:
            friend_ints = get_persons_int(p_friends)
            ge_2_friends = Plus(map(lambda f: Ite(Equals(f, Int(t)), Int(1), Int(0)), friend_ints)) >= 1
            hard_cons.append(Implies(Equals(p_int, Int(t)), ge_2_friends))
    elif len(p_friends) > 0:
        c_cons = get_couple_constraint(p, p_friends[0])
        hard_cons.append(c_cons)

# # each table should not have more than 70% from
# # any single group
for t, t_cap in tqdm(list(enumerate(tables))):
#     continue
    group_cap = int(t_cap * 0.7)
    affected_groups = { k: v for k, v in groups.items() if len(v) > group_cap }
    for g, g_members in affected_groups.items():
        g_member_ints = get_persons_int(g_members)
        g_members_on_t = Plus(map(lambda g: Ite(Equals(g, Int(t)), Int(1), Int(0)), g_member_ints))
        cons = g_members_on_t <= group_cap
        hard_cons.append(cons)

print ('constraints enumerated')

100%|██████████| 50/50 [00:00<00:00, 725.81it/s]
100%|██████████| 5/5 [00:00<00:00, 610.81it/s]

constraints enumerated





In [8]:
# %time get_model(And(hard_cons))
print datetime.datetime.now()
t_start = time.time()
%time model = get_model(And(hard_cons), solver_name='z3')
t_end = time.time()

total_time = t_end - t_start
print total_time
print datetime.datetime.now()

2018-05-25 01:06:55.006261
CPU times: user 390 ms, sys: 87.8 ms, total: 478 ms
Wall time: 536 ms
0.537671089172
2018-05-25 01:06:55.544822


In [9]:
# write_smtlib(And(hard_cons), 'int.smt2')

In [10]:
rev_g = defaultdict(list)
for g, g_mems in groups.items():
    for mem in g_mems:
        rev_g[mem].append(g)    

m_vals = [int(model.get_value(get_person_int(x)).serialize()) for x in person_ids]
table_vals = defaultdict(lambda: defaultdict(list))
for i, val in enumerate(m_vals):
    table_vals[val][i].extend(rev_g[i])

table_vals = {k: dict(v) for k, v in table_vals.items()}

table_group_stats = {k: Counter(reduce(lambda x, y: x + y, v.values())) for k, v in table_vals.items()}

pp(table_group_stats)
pp(table_vals)

{0: Counter({6: 4, 8: 2, 1: 2, 7: 2}),
 1: Counter({3: 3, 5: 3, 1: 2, 2: 2}),
 2: Counter({8: 2, 1: 2, 3: 2, 4: 2, 5: 2}),
 3: Counter({3: 3, 6: 3, 5: 2, 7: 2}),
 4: Counter({8: 5, 2: 3, 4: 2})}
{0: {0: [1],
     1: [1],
     32: [6],
     33: [6],
     34: [6],
     36: [6],
     37: [7],
     38: [7],
     47: [8],
     48: [8]},
 1: {2: [1],
     3: [1],
     8: [2],
     9: [2],
     16: [3],
     17: [3],
     18: [3],
     23: [5],
     24: [5],
     25: [5]},
 2: {4: [1],
     5: [1],
     13: [3],
     15: [3],
     19: [4],
     20: [4],
     28: [5],
     29: [5],
     41: [8],
     49: [8]},
 3: {11: [3],
     12: [3],
     14: [3],
     26: [5],
     27: [5],
     30: [6],
     31: [6],
     35: [6],
     39: [7],
     40: [7]},
 4: {6: [2],
     7: [2],
     10: [2],
     21: [4],
     22: [4],
     42: [8],
     43: [8],
     44: [8],
     45: [8],
     46: [8]}}
