In [12]:
from collections import Counter, defaultdict
import datetime
import math
import numpy as np
import pandas as pd
from pprint import PrettyPrinter
import time
from tqdm import tqdm
from z3 import *

pp = PrettyPrinter().pprint

In [5]:
## This cell is an example of the variables
## required in the data.

# persons is a list of ids (for simplicity)
# we can be more sophisticated and use a 
# Person class to represent each person
persons = range(20)


# tables is a list of table sizes
# e.g. if tables = [3, 4, 5], it means
# there are 3 tables of sizes 3, 4 and 5
# respectively.
tables = [5] * 4

# groups is a dictionary that maps a group id
# to a list of all the people in that group
# e.g. if groups[0] == [1, 2], it means
# person 1 and 2 belong to group 0
groups = defaultdict(list)
groups[0] = [0, 1, 2, 3]
groups[1] = [6, 7, 8, 9]

# friends is a dictionary that maps a person_id
# to a list of all that person's friends.
# e.g. if friends[0] == [1, 2, 3], it means 
# that person 0 has three friends: persons 1, 2, 3
friends = defaultdict(list)
friends[0] = [1]
friends[1] = [2, 3, 4]

# enemies is a list of pairs where
# each pair holds two people that 
# shouldn't be seated on the same table
enemies = []

# couples is a list of pairs where
# each pair holds two people that
# must be seated on the same table
couples = []

In [6]:
df = pd.read_csv('../example_data/332_attendees.csv')

In [7]:
data_dict = dict()
for col in df.columns:
    data_dict[col] = list(filter(lambda x: not(math.isnan(x)), df[col]))
    
persons = data_dict['attendees_id']
tables = list(map(int, data_dict['table_sizes']))

get_person_index = lambda x: persons.index(x)
enemy_a_indices = map(get_person_index, data_dict['enemy_a'])
enemy_b_indices = map(get_person_index, data_dict['enemy_b'])
couple_a_indices = map(get_person_index, data_dict['must_be_together_a'])
couple_b_indices = map(get_person_index, data_dict['must_be_together_b'])
enemies = list(zip(enemy_a_indices, enemy_b_indices))
couples = list(zip(couple_a_indices, couple_b_indices))
group_member_indices = list(map(get_person_index, data_dict['group_member_id']))

groups = defaultdict(list)
for group_id, group_member_id in zip(data_dict['group_id'], group_member_indices):
    groups[group_id].append(group_member_id)

friends_set = defaultdict(set)
for _, g_members in groups.items():
    for member in g_members:
        friends_set[member] = friends_set[member].union(g_members)

friends = defaultdict(list)
for p, p_friends in friends_set.items():
    friends[p] = list(p_friends.difference(set([p])))

In [8]:
person_ids = range(len(persons))
table_ids = range(len(tables))

In [9]:
extra_bits_num = int(math.ceil(math.log(len(persons), 2)))
extra_bits_num = 0 # only use extra bits if youre doing option 1 for person constraint
bit_vec_size = len(tables) + extra_bits_num

def get_person_bv(p_id):
    return BitVec('p{}'.format(p_id), bit_vec_size)

def get_people_bvs(p_ids=person_ids):
    return map(get_person_bv, p_ids)

def get_table_mask(t_id):
    return (BitVecVal(1, bit_vec_size) << t_id)

def get_couple_constraint(c1, c2):
    c1_bv = get_person_bv(c1)
    c2_bv = get_person_bv(c2)
    return (c1_bv & c2_bv != 0)

def get_enemy_constraint(e1, e2):
    e1_bv = get_person_bv(e1)
    e2_bv = get_person_bv(e2)
    return (e1_bv & e2_bv == 0)

In [14]:
hard_cons = []

# each person can only be in one table
for p_i, _ in enumerate(persons):
    p_bv = get_person_bv(p_i)
    hard_cons.append((And((p_bv != 0), (p_bv & (p_bv - 1)) == 0)))
    hard_cons.append(ULE(p_bv, get_table_mask(table_ids[-1])))
    
# each table has a fixed capacity
for t_i, t_cap in enumerate(tables):
    p_bvs = get_people_bvs()
    
    # option 1
#     table_mask = get_table_mask(t_i)
#     t_sum = sum(map(lambda p_bv: p_bv & table_mask, p_bvs))
#     hard_cons.append(t_sum <= (table_mask * t_cap))
    
#     option 2
    t_sum = Sum(map(lambda p_bv: LShR(p_bv, t_i) & 1, p_bvs))
    hard_cons.append(t_sum <= t_cap)
    
# couples must be seated together
for c1, c2 in couples:
    c_cons = get_couple_constraint(c1, c2)
    hard_cons.append(c_cons)

# enemies must be seated separately
for e1, e2 in enemies:
    e_cons = get_enemy_constraint(e1, e2)
    hard_cons.append(e_cons)

# Friends should be seated together.
# If x only has one friend y, x should
# be seated with y. If x has more than
# one friend, x should be seated with
# at least two friends.
for p, p_friends in tqdm(friends.items()):
#     continue
    if len(p_friends) > 1:
        for t in table_ids:
            table_mask = get_table_mask(t)
            p_bvs = get_people_bvs(p_friends)
            t_sum = sum(map(lambda p_bv: p_bv & table_mask, p_bvs))
            p_on_t = (get_person_bv(p) & table_mask) != 0
            hard_cons.append(Implies(p_on_t, t_sum >= 1 * table_mask))
    elif len(p_friends) > 0:
        c_cons = get_couple_constraint(p, p_friends[0])
        hard_cons.append(c_cons)

# each table should not have more than 70% from
# any single group
for t, t_cap in tqdm(list(enumerate(tables))):
    continue
    group_cap = int(t_cap * 0.7)
    affected_groups = { k: v for k, v in groups.items() if len(v) > group_cap }
    for g, g_members in affected_groups.items():
        p_bvs = get_people_bvs(g_members)
        t_sum = sum(map(lambda p_bv: p_bv & table_mask, p_bvs))
        hard_cons.append(t_sum <= group_cap * table_mask)
    
# # tables should be as full as possible
# # this is a difficult problem, hence we 
# # add it to soft_cons
# for t in tqdm(table_ids):
#     t_sizes = []
#     t_poss = get_filtered_assignments(t_ids=[t])
#     x = [PbGe(map(lambda x: (x, 1), t_poss), i) for i in range(1, tables[t] + 1)]
#     soft_cons.extend(zip(x, range(1, tables[t] + 1)))

print ('constraints enumerated')

100%|██████████| 222/222 [00:31<00:00,  6.95it/s]
100%|██████████| 34/34 [00:00<00:00, 13066.37it/s]

constraints enumerated





In [None]:
def get_results(model):
    return map(lambda p: int(math.log(int(str(model.eval(get_person_bv(p)))), 2)), person_ids)

print datetime.datetime.now()
t_start = time.time()
s = Solver()
s.add(And(hard_cons))
f = s.sexpr()
%time check_res = s.check()
if check_res == sat:
    print ('problem solved')
    results = get_results(s.model())
#     print pd.DataFrame(results)
    print(Counter(results))
else:
    print('not satisfiable')

t_end = time.time()
total_time = t_end - t_start
print total_time
print datetime.datetime.now()

2018-05-24 17:55:41.487173


In [None]:
print datetime.datetime.now()


In [None]:
with open('f', 'w') as f:
    f.write(s.sexpr())
    f.write('\n(check-sat)\n')

In [76]:
bit_vec_size

12

In [77]:
len(table_ids)

12