In [1]:
from collections import Counter, defaultdict
import datetime
import math
import pandas as pd
from pprint import PrettyPrinter
from pysmt.shortcuts import *
import time
from tqdm import tqdm

# pp(obj): shorthand for pretty-printing nested dicts/Counters in later cells.
_pretty_printer = PrettyPrinter()
pp = _pretty_printer.pprint

In [2]:
## Illustrative values for every input the seating model expects.
## A later cell overwrites all of them with values read from the
## attendee CSV, so this cell mainly documents their shapes.

# persons: iterable of person ids (plain ints here; a richer Person
# class could be used instead).
persons = range(50)

# tables: one entry per table holding its seat count,
# e.g. [3, 4, 5] describes three tables seating 3, 4 and 5 people.
tables = [10 for _ in range(5)]

# groups: group id -> list of the person ids in that group,
# e.g. groups[0] == [1, 2] means persons 1 and 2 belong to group 0.
groups = defaultdict(list)
groups.update({
    0: [0, 1, 2, 3],
    1: [6, 7, 8, 9],
})

# friends: person id -> list of that person's friends,
# e.g. friends[0] == [1, 2, 3] means person 0 has friends 1, 2 and 3.
friends = defaultdict(list)
friends.update({
    0: [1],
    1: [2, 3, 4],
})

# enemies: (a, b) pairs that must NOT share a table.
enemies = list()

# couples: (a, b) pairs that MUST share a table.
couples = list()

In [3]:
# Raw attendee table; the next cell unpacks its columns into model inputs.
ATTENDEES_CSV = 'example_data/332_attendees.csv'
df = pd.read_csv(ATTENDEES_CSV)

In [4]:
# Build the model inputs from the attendee CSV. Each input appears to be
# stored as its own column, padded with blanks (NaN) up to the length of
# the longest column, so blanks are dropped before use.


def _non_missing(values):
    """Return ``values`` as a list with missing (NaN/None) entries removed.

    Uses ``pd.notnull`` rather than ``math.isnan`` so that non-float
    columns (e.g. string ids) do not raise ``TypeError``.
    """
    return [v for v in values if pd.notnull(v)]


data_dict = {col: _non_missing(df[col]) for col in df.columns}

persons = data_dict['attendees_id']
tables = [int(size) for size in data_dict['table_sizes']]

# attendee id -> 0-based person index. Built once, so each lookup is O(1)
# instead of an O(n) ``list.index`` scan. setdefault keeps the first
# occurrence of a duplicated id, exactly as ``list.index`` does.
_person_index_by_id = {}
for _idx, _attendee_id in enumerate(persons):
    _person_index_by_id.setdefault(_attendee_id, _idx)


def get_person_index(x):
    """Return the 0-based index of attendee id ``x`` in ``persons``.

    Raises:
        ValueError: if ``x`` is not an attendee id (the same exception
            type that ``persons.index`` raised).
    """
    try:
        return _person_index_by_id[x]
    except KeyError:
        raise ValueError('{!r} is not an attendee id'.format(x))


def _paired_indices(col_a, col_b):
    """Return ``(index_a, index_b)`` for every row where both columns are filled.

    Blanks are dropped row-wise to keep each pair aligned. Filtering the two
    columns separately and zipping them would shift every later pair if a
    blank appeared in the middle of one column.
    """
    rows = df[[col_a, col_b]].dropna()
    return [(get_person_index(a), get_person_index(b))
            for a, b in zip(rows[col_a], rows[col_b])]


enemies = _paired_indices('enemy_a', 'enemy_b')
couples = _paired_indices('must_be_together_a', 'must_be_together_b')

# group id -> list of member person indices (same row-wise alignment).
_group_rows = df[['group_id', 'group_member_id']].dropna()
group_member_indices = [get_person_index(m) for m in _group_rows['group_member_id']]

groups = defaultdict(list)
for group_id, member_idx in zip(_group_rows['group_id'], group_member_indices):
    groups[int(group_id)].append(member_idx)

# Anyone who shares at least one group with a person counts as their friend.
friends_set = defaultdict(set)
for g_members in groups.values():
    for member in g_members:
        friends_set[member].update(g_members)

friends = defaultdict(list)
for p, p_friends in friends_set.items():
    # A person is not their own friend; sorted() makes the order reproducible.
    friends[p] = sorted(p_friends - {p})

In [5]:
# Dense 0-based indices for people and tables. Each person's solver
# variable is constrained to hold one of the `table_ids` values.
table_ids = range(len(tables))
person_ids = range(len(persons))

In [6]:
def get_person_int(p_id):
    """Return the INT solver symbol ``p<p_id>`` for person index ``p_id``.

    The symbol's value in a model is the table that person is seated at.
    """
    symbol_name = 'p{}'.format(p_id)
    return Symbol(symbol_name, INT)

def get_persons_int(p_ids=None):
    """Return the solver symbols for the given person indices, as a list.

    Args:
        p_ids: iterable of person indices. When omitted, every index in the
            current ``person_ids`` is used.

    Returns:
        list of INT symbols, one per entry of ``p_ids``, in the same order.
    """
    # Resolve the default at call time. A def-time default would silently
    # keep whatever ``person_ids`` existed when this cell last ran, even
    # after the data cells are re-run.
    if p_ids is None:
        p_ids = person_ids
    # Return a list, not a one-shot ``map`` iterator (Python 3), so callers
    # can iterate over the result more than once.
    return [get_person_int(p_id) for p_id in p_ids]

def get_couple_constraint(c1, c2):
    """Return a formula forcing persons ``c1`` and ``c2`` onto the same table."""
    return Equals(get_person_int(c1), get_person_int(c2))

def get_enemy_constraint(e1, e2):
    """Return a formula forbidding persons ``e1`` and ``e2`` from sharing a table."""
    same_table = Equals(get_person_int(e1), get_person_int(e2))
    return Not(same_table)

In [7]:
# Collect the hard constraints of the seating model. Person i is the INT
# variable p<i>, whose value is the index of the table they sit at.
hard_cons = []

# Upper bound on the share of any one table's seats a single group may take.
MAX_GROUP_SHARE_PCT = 70


def _seats_taken(person_ints, t):
    """Return a term counting how many of ``person_ints`` are seated at table ``t``."""
    return Plus([Ite(Equals(p, Int(t)), Int(1), Int(0)) for p in person_ints])


# each person sits at exactly one valid table index
last_table = table_ids[-1]
for p_i in range(len(persons)):
    p = get_person_int(p_i)
    hard_cons.append(And(p >= 0, p <= last_table))

# each table has a fixed capacity
all_person_ints = list(get_persons_int(person_ids))
for t_i, t_cap in enumerate(tables):
    hard_cons.append(_seats_taken(all_person_ints, t_i) <= t_cap)

# couples must be seated together
for c1, c2 in couples:
    hard_cons.append(get_couple_constraint(c1, c2))

# enemies must be seated separately
for e1, e2 in enemies:
    hard_cons.append(get_enemy_constraint(e1, e2))

# A person with several friends must share their table with at least two
# of them; a person with exactly one friend must sit with that friend.
for p, p_friends in tqdm(friends.items()):
    if len(p_friends) > 1:
        p_int = get_person_int(p)
        # Materialised once: the symbols are reused for every table below.
        friend_ints = list(get_persons_int(p_friends))
        for t in table_ids:
            ge_2_friends = _seats_taken(friend_ints, t) >= 2
            hard_cons.append(Implies(Equals(p_int, Int(t)), ge_2_friends))
    elif len(p_friends) == 1:
        hard_cons.append(get_couple_constraint(p, p_friends[0]))

# no table may be filled more than MAX_GROUP_SHARE_PCT% by any single group
for t, t_cap in tqdm(list(enumerate(tables))):
    # Integer arithmetic on purpose: 0.7 has no exact binary representation,
    # so int(t_cap * 0.7) can land just below a whole number and floor one
    # seat too low.
    group_cap = t_cap * MAX_GROUP_SHARE_PCT // 100
    for g_members in groups.values():
        if len(g_members) <= group_cap:
            continue  # this group is too small to ever exceed the cap
        member_ints = list(get_persons_int(g_members))
        hard_cons.append(_seats_taken(member_ints, t) <= group_cap)

print('constraints enumerated')

100%|██████████| 222/222 [00:02<00:00, 86.14it/s]
100%|██████████| 34/34 [00:00<00:00, 220.34it/s]

constraints enumerated





In [8]:
# %time get_model(And(hard_cons))
print datetime.datetime.now()
t_start = time.time()
%time model = get_model(And(hard_cons), solver_name='yices')
if model is None:
    print "unsat"
t_end = time.time()

total_time = t_end - t_start
print total_time
print datetime.datetime.now()

2018-05-25 01:33:39.267255
CPU times: user 51.2 s, sys: 200 ms, total: 51.4 s
Wall time: 51.5 s
51.5444221497
2018-05-25 01:34:30.815453


In [9]:
# Summarise the solution: for each table, count how many seated people come
# from each group. A `None` key counts people who belong to no group.
if model is None:
    raise RuntimeError('solver returned no model (unsat); nothing to summarise')

# person index -> list of the group ids that person belongs to
rev_g = defaultdict(list)
for g, g_mems in groups.items():
    for mem in g_mems:
        rev_g[mem].append(g)

# Flatten with comprehensions instead of reduce(): reduce is not a builtin on
# Python 3, raises TypeError on an empty `groups` when given no initial value,
# and its repeated list concatenation is quadratic.
group_mems = {mem for g_mems in groups.values() for mem in g_mems}
no_group_mems = set(person_ids) - group_mems

for p in no_group_mems:
    rev_g[p].append(None)

# table index assigned to each person by the model
m_vals = [int(model.get_value(get_person_int(x)).serialize()) for x in person_ids]

# table -> person -> that person's group ids
table_vals_raw = defaultdict(lambda: defaultdict(list))
for i, val in enumerate(m_vals):
    table_vals_raw[val][i].extend(rev_g[i])

table_vals = {k: dict(v) for k, v in table_vals_raw.items()}

table_group_stats = {
    k: Counter(g for g_list in v.values() for g in g_list)
    for k, v in table_vals.items()
}

pp(table_group_stats)

{0: Counter({33: 4, 35: 3, None: 3, 28: 1, 22: 1}),
 1: Counter({20: 4, 24: 3, 31: 3}),
 2: Counter({8: 3, None: 3, 23: 3, 21: 2, 30: 2}),
 3: Counter({None: 7, 44: 3, 43: 1}),
 4: Counter({None: 7, 18: 3}),
 5: Counter({None: 10}),
 6: Counter({None: 10}),
 7: Counter({None: 10}),
 8: Counter({None: 10}),
 9: Counter({None: 10}),
 10: Counter({None: 6, 26: 4}),
 11: Counter({14: 4, 19: 3, None: 3, 13: 1}),
 12: Counter({40: 3, 19: 3, 39: 3, None: 1}),
 13: Counter({40: 4, 42: 3, 45: 3, None: 1}),
 14: Counter({45: 4, 42: 3, 6: 3, 10: 1, 36: 1}),
 15: Counter({23: 5, 3: 3, 12: 3, 17: 1}),
 16: Counter({37: 5, 35: 3, None: 3, 29: 1}),
 17: Counter({34: 3, 9: 3, 27: 3, None: 2, 25: 2, 36: 1, 6: 1, 16: 1, 23: 1}),
 18: Counter({37: 4, 35: 3, 16: 3, 33: 1, 2: 1, 11: 1}),
 19: Counter({38: 6, None: 3}),
 20: Counter({36: 3, 37: 3, 1: 2, 6: 2, None: 2, 23: 2, 10: 1, 4: 1}),
 21: Counter({41: 4, 43: 4, 32: 2, 10: 2, 2: 1, 4: 1, 44: 1, None: 1}),
 22: Counter({37: 6, 35: 3, 27: 3, 6: 1, 7: 1, 