# A*-search using `aima-python`

Algorithm implementations taken from [here](https://github.com/aimacode/aima-python/blob/master/search-4e.ipynb)

* *State* is defined by gifts in bags

* *Goal states* are defined by filled bags satisfying problem conditions

* *Actions* : put a gift in a bag with a minimal weight

In [1]:
# https://ipython.org/ipython-doc/3/config/extensions/autoreload.html
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
np.random.seed(2016)

In [3]:
from search import Problem, astar_search, uniform_cost_search, state_sequence, action_sequence
import sys
sys.path.append('../common')
from utils import weight3 as weight_fn

In [4]:
n_bags = 1
max_weight = 50

n_horses = 1000
n_balls = 1100
n_bikes = 500
n_trains = 1000
n_coals = 166
n_books = 1200
n_dolls = 1000
n_blocks = 1000
n_gloves = 200

available_gifts = {
    "horse": n_horses,
    "ball": n_balls,
    "bike": n_bikes,
    "train": n_trains,
    "coal": n_coals,
    "book": n_books,
    "doll": n_dolls,
    "blocks": n_blocks,
    "gloves": n_gloves
}

type_cost = {
#     "horse": 1.2,
#     "ball": 1.2,
#     "bike": 2,
#     "train": 1.15,
    "coal": 0.5,
#     "book": 2.0,
#     "doll": 1.3,
#     "blocks": 0.5,
#     "gloves": 3    
}

gift_types = sorted(list(available_gifts.keys()))
n_types = len(gift_types)

In [5]:
len(gift_types), gift_types

(9,
 ['ball',
  'bike',
  'blocks',
  'book',
  'coal',
  'doll',
  'gloves',
  'horse',
  'train'])

State is tuple (bags) of tuples (gifts) :

```
( 
#  ball, bike, block, book, coal, doll, gloves, horse, train  
    (0,1,0,3,0,0,0,0,2), # bag 1
    (0,0,0,0,0,2,5,6,0), # bag 2
    ...
)
```


In [6]:
def bag_weight(bag, n1=100):
    weight = 0
    for index, count in enumerate(bag):
        for i in range(count):
            weight += weight_fn(index, n1)
    return weight

In [7]:
def score(state, count=100):
    scores = np.zeros(count)
    for c in range(count):
        score = 0
        for bag in state:
            total_weight_ = bag_weight(bag, n1=1)
            if total_weight_ < max_weight:
                score += total_weight_
        scores[c] = score
    return np.mean(scores)

In [8]:
s = ((0, 0, 3, 0, 0, 1, 0, 0, 0),)
score(s)

36.175530726146057

In [9]:
available_gifts

{'ball': 1100,
 'bike': 500,
 'blocks': 1000,
 'book': 1200,
 'coal': 166,
 'doll': 1000,
 'gloves': 200,
 'horse': 1000,
 'train': 1000}

In [768]:
class SantasBagsProblem(Problem):
            
    def _get_gift_type_indices(self, state):
        out = []
        types = np.sum(np.array(state), axis=0)
        for index, t in enumerate(types):
            if t < self.available_gifts[self.gift_types[index]]:
                out.append(index)
        return out
        
    def actions(self, state):
        """Return a list of actions executable in this state."""                        
        _gift_type_indices = self._get_gift_type_indices(state)
        if len(_gift_type_indices) == 0:
            print("No gifts available to create actions")
            return []
    
        print("_gift_type_indices : ", _gift_type_indices)
        # find a bag with a minimal weight  
        min_weight_bag_index = 0
        min_weight = self.max_weight
        for i, bag in enumerate(state):
            w = self.bag_weight_fn(bag)
            if min_weight > w:
                min_weight_bag_index = i
                min_weight = w
                
#         print("min_weight_bag_index : ", min_weight_bag_index)
        
        actions = []
        bag_weight = self.bag_weight_fn(state[min_weight_bag_index])
        for _index in _gift_type_indices:
            gift_weight = self.weight_fn(_index, 50)        
            if bag_weight + gift_weight < self.max_weight:
                actions.append((min_weight_bag_index, _index))
        
        print("actions: ", actions)
#         if len(actions) == 0:
#             print("No actions found for the state : ", state, min_weight_bag_index, bag_weight)

        return actions
    
    def result(self, state, action):
        """The state that results from executing this action in this state."""
        bag_id, gift_type_index = action
        print("-- result : input state: ", state, "action: ", action)                
        new_state = list(state)
        bag = list(new_state[bag_id])
        bag[gift_type_index] += 1
        new_state[bag_id] = tuple(bag)
        print("-- result : output state: ", new_state)
        return tuple(new_state)
        
    def is_goal(self, state):
        """True if the state is a goal."""        
        for bag in state:
            if sum(bag) < 3:
#                 print("- A bag with less than 3 gifts found : ", state)
                return False
        
        # Check if solution is available:
        types = np.sum(np.array(state), axis=0)
        for index, t in enumerate(types):
            if t > self.available_gifts[self.gift_types[index]]:
                return False

        mean_score = self._validation(state)
#         if mean_score > self.goal_score:
        print("- Mean score : ", mean_score, " / ", self.goal_score, state)
        return mean_score > self.goal_score

    def step_cost(self, state, action, result=None):
        """The cost of taking this action from this state."""
        if self.type_cost is not None:
            bag_id, gift_type_index  = action
            gift_type = self.gift_types[gift_type_index]
            if gift_type in self.type_cost:
                return self.type_cost[gift_type]  # Override this if actions have different costs
            return 1.0
        return 1.0
                
    def _validation(self, state, count=100):
        scores = np.zeros(count)
        rejected = 0
        for c in range(count):
            score = 0
            for bag in state:
                total_weight_ = self.bag_weight_fn(bag, n1=1)
                if total_weight_ < self.max_weight:
                    score += total_weight_
                else:
                    rejected += 1
            scores[c] = score
#         if rejected > 0:
#             print("Rejected bags : %f / %i" % (rejected*1.0/count, len(state)))
        return np.mean(scores)

In [769]:
alpha = 0.7
goal_score = n_bags*max_weight*alpha
print("Goal score: ", goal_score)

('Goal score: ', 35.0)


In [770]:
def compute_normal_identical(mu, sigma, a=50):
    """
    Solve n*mu + 3*sigma*sqrt(n) < a
    return: n, Mu, Sigma
    """
    for n in range(100):
        y = mu*n + 3.0*sigma*np.sqrt(n)
        if y > a:
            break  
    n -= 1
    m = n*mu
    s = sigma*np.sqrt(n)
    return n, m, s

In [771]:
# # Horse : 
# ag = deepcopy(available_gifts)
# n, m, s = compute_normal_identical(5, 2)
# initial_state=tuple([tuple([ag['horse'].pop() for i in range(n)]) for j in range(n_bags)])
# initial_state

In [772]:
initial_state=tuple([tuple([0]*n_types)]*n_bags)

In [773]:
# initial_state = ((0, 0, 2, 0, 0, 1, 0, 0, 0), )
# initial_state

In [774]:
p = SantasBagsProblem(initial=initial_state,
                      gift_types=gift_types, 
                      available_gifts=available_gifts,
                      max_weight=max_weight,    
                      type_cost=type_cost,
                      weight_fn=weight_fn,
                      bag_weight_fn=bag_weight,
                      goal_score=goal_score)

Define heuristic function :


In [775]:
from math import floor

def round_value(x, digits=1):
    x = floor(x*10**digits) / 10**digits
    return x


def h12(state):     
#     h1 = 0
#     h2 = 0
#     for bag in state:
#         h1 += bag_weight(bag, 200)
#         w = bag_weight(bag, 200)
#         h2 += abs(max_weight*alpha - w)
#     h1 = abs(goal_score - h1) / n_bags    
#     return max(h1, h2) 
    return round_value((goal_score - score(state))**2/goal_score**2)


def h3(state):     
    return round_value(alpha*(max_weight*n_bags - score(state))**2/(max_weight*n_bags)**2)


def h4(state): 
    h1 = 0
    for bag in state:
        h1 += bag_weight(bag, 200)
    return abs(goal_score - h1)/goal_score


def h5(state):
    h = 0
    for bag in state:
        h += np.count_nonzero(bag) * 1.0 / len(bag)
    return h


def h6(state):
    h = 0.0
#     for bag in state:
#         h += (len(bag) - np.count_nonzero(bag)) / len(bag)
    return h * 1.0  / len(state)
                

def h7(state):
    rejected = 0
    count = 100
    for c in range(count):
        for bag in state:
            total_weight_ = bag_weight(bag, n1=1)
            if total_weight_ > max_weight:
                rejected += 1
    return rejected * 1.0 / count
        


def final_heuristic_fn(state):  
    res = np.max(np.array([h12(state), h3(state), h4(state), h5(state), h7(state)]))
#     res = (h12(state) + h5(state) + h6(state) + h7(state))/max_weight
#     print ("final_heuristic_fn : ", res, state)
    return res
#     return np.max(np.array([h12(state),]))
#     return np.max(np.array([h12(state), h3(state), h4(state), h5(state)]))

In [776]:
initial_state = ((1, 1, 2, 2, 1, 0, 1, 1, 1), )
h12(initial_state), h3(initial_state), h5(initial_state), h7(initial_state), final_heuristic_fn(initial_state)

(0.9, 0.6, 0.8888888888888888, 0.97, 1.5656165672183948)

In [777]:
ag2 = available_gifts.copy()
# print(update_available_gifts(ag2, result.state))
ag2['blocks'] = 0
ag2['doll'] = 0
ag2['bike'] = 0
ag2['horse'] = 0
ag2['train'] = 0
ag2['ball'] = 2
ag2['book'] = 2
ag2['gloves'] = 0
ag2

{'ball': 2,
 'bike': 0,
 'blocks': 0,
 'book': 2,
 'coal': 166,
 'doll': 0,
 'gloves': 0,
 'horse': 0,
 'train': 0}

In [778]:
from time import time
p.initial = ((0, 0, 0, 0, 1, 0, 0, 0, 0), )# initial_state
p.available_gifts = ag2

In [779]:
type_cost = {
#     "horse": 150,
#     "ball": 0.1,
#     "bike": 150,
#     "train": 150,
    "coal": 0.150,
    "book": 0.750,
#     "doll": 150,
#     "blocks": 150,
    "gloves": 0.750    
}
type_cost

{'book': 0.75, 'coal': 0.15, 'gloves': 0.75}

In [783]:
tic = time()
result = astar_search(p, final_heuristic_fn, verbose=True)
print(result)
print("Elapsed: ", time() - tic)

('Frontier : ', [((0, 0, 0, 0, 1, 0, 0, 0, 0),)])
('Check node: ', ((0, 0, 0, 0, 1, 0, 0, 0, 0),), ' | ', 0, ', ', 0.33714162627222038)
('_gift_type_indices : ', [0, 3, 4])
('actions: ', [(0, 0), (0, 3), (0, 4)])
('-- result : input state: ', ((0, 0, 0, 0, 1, 0, 0, 0, 0),), 'action: ', (0, 0))
('-- result : output state: ', [(1, 0, 0, 0, 1, 0, 0, 0, 0)])
('-- result : input state: ', ((0, 0, 0, 0, 1, 0, 0, 0, 0),), 'action: ', (0, 3))
('-- result : output state: ', [(0, 0, 0, 1, 1, 0, 0, 0, 0)])
('-- result : input state: ', ((0, 0, 0, 0, 1, 0, 0, 0, 0),), 'action: ', (0, 4))
('-- result : output state: ', [(0, 0, 0, 0, 2, 0, 0, 0, 0)])
('Frontier : ', [((1, 0, 0, 0, 1, 0, 0, 0, 0),), ((0, 0, 0, 1, 1, 0, 0, 0, 0),), ((0, 0, 0, 0, 2, 0, 0, 0, 0),)])
('Check node: ', ((0, 0, 0, 0, 2, 0, 0, 0, 0),), ' | ', 0.15, ', ', 0.53000000000000003)
('_gift_type_indices : ', [0, 3, 4])
('actions: ', [(0, 0), (0, 3)])
('-- result : input state: ', ((0, 0, 0, 0, 2, 0, 0, 0, 0),), 'action: ', (0, 0))
(

In [767]:
result.state

AttributeError: 'NoneType' object has no attribute 'state'

In [26]:
h12(result.state), h3(result.state), h6(result.state), h7(result.state)

(3.0714772504384142e-05, 0.051752795317218887, 0.0, 0.05)

In [456]:
p._validation(((0, 0, 0, 0, 1, 0, 0, 0, 0),))

21.887851989590789

In [457]:
def update_available_gifts(ag, state):
    sum_gifts = np.sum(np.array(state), axis=0)
    for v, gift_type in zip(sum_gifts, gift_types):
        assert ag[gift_type] - v >= 0, "Found state is not available : {}, {}".format(state, ag)
        ag[gift_type] = ag[gift_type] - v

In [505]:
type_cost = {
#       "ball": 0.9,
#     "horse": 1.2,
#     "bike": 0.5,
#     "train": 0.9,
#     "coal": 0.7,
#     "book": 1.0,
#     "doll": 1.0,
    "blocks": 0.5,
#     "gloves": 0.3    
}

In [506]:
# type_cost = {"horse": 0.9, "train": 0.9, "bike": 1.9, "book": 1.9, "gloves": 1.9, "ball": 1.9}

In [507]:
def remove_gifts(state, gifts_to_remove=2):
    _gift_removed = 0
    new_state = list(state)
    for bag_index, bag in enumerate(state):
        for i in range(gifts_to_remove):
            gift_type_index = np.argmax(bag)
            if bag[gift_type_index] > 0:
                bag = list(new_state[bag_index])
                bag[gift_type_index] -= 1
                new_state[bag_index] = tuple(bag)
                _gift_removed += 1
    if _gift_removed == 0:
        state=tuple([tuple([0]*n_types)]*n_bags)
    else:
        print("-- Remove some gift : ", state, tuple(new_state))
        state = tuple(new_state)
    return state


# def remove_gifts2(state, gifts_to_remove=1):
#     _gift_removed = 0
#     new_state = list(state)
#     for bag_index, bag in enumerate(state):
#         for i in range(gifts_to_remove):
            
#             for g in bag:
#             gift_type_index = np.argmax(bag)
#             weight_fn(index, n1)
            
#             if bag[gift_type_index] > 0:
#                 bag = list(new_state[bag_index])
#                 bag[gift_type_index] -= 1
#                 new_state[bag_index] = tuple(bag)
#                 _gift_removed += 1
#     if _gift_removed == 0:
#         state=tuple([tuple([0]*n_types)]*n_bags)
#     else:
#         print("-- Remove some gift : ", state, tuple(new_state))
#         state = tuple(new_state)
#     return state

In [508]:
from copy import deepcopy

In [509]:
total_n_bags = 1000
n_bags = 1

total_state=[]
found_goal_states=[]
ag=deepcopy(available_gifts)
counter = 0

In [510]:
alpha = 0.7
goal_score = n_bags*max_weight*alpha
print("Goal score: ", goal_score)


('Goal score: ', 35.0)


Each bag is filled using A\*-search algorithm. However, the initial state is not always the empty bag. 

When a *goal* state is found, its stored in a list without repeating. Next state is searched from the previous found goal state as initial. If nothing is found a state from stored goal states is used to restart the search. If no stored goal states, remove some gifts and restart the search


In [523]:
gifts_to_remove = 2
empty_state = tuple([tuple([0]*n_types)]*n_bags)
state=(total_state[-1],) if len(total_state) > 0 else empty_state

while n_bags * counter < total_n_bags:
    
    print("Filled bags : ", n_bags * counter, "/", total_n_bags)
    p = SantasBagsProblem(initial=tuple(state),
                          gift_types=gift_types, 
                          available_gifts=ag,
                          max_weight=max_weight,    
                          type_cost=type_cost,
                          weight_fn=weight_fn,
                          bag_weight_fn=bag_weight,
                          goal_score=goal_score)
    tic = time()
    result = astar_search(p, final_heuristic_fn, verbose=True)
    if result is not None:
        print("- Got a result")
        update_available_gifts(ag, result.state)
        if len(found_goal_states) == 0 or found_goal_states[-1] != result.state:
            found_goal_states.append(result.state)
        total_state += result.state
        counter += 1
        state=(total_state[-1],)
    else:
        print("-- Result is none | len(found_goal_states)=", len(found_goal_states))
        if len(found_goal_states) > 0:
            state=found_goal_states.pop()
            print("--- Restart from : ", state)
        else:
            if state != empty_state:
                state=remove_gifts(state)
            else:
                alpha -= 0.05
                goal_score = n_bags*max_weight*alpha
                print(">>> Goal score changed: ", goal_score)    
        
    if counter > 0 and (n_bags * counter % 20) == 0:
        s = score(total_state)
        print(">>> Current score: ", s, s * (total_n_bags) / (n_bags * counter) )
        
    if counter > 0 and (n_bags * counter % 30) == 0:
        print(">>> Currently available gifts : ", [(k,ag[k]) for k in gift_types])
        
    print("- Elapsed: ", time() - tic)

('Filled bags : ', 334, '/', 1000)
('Check node: ', ((1, 0, 1, 0, 0, 3, 0, 0, 1),), ' | ', 0, ', ', 0.66666666666666663)
('Check node: ', ((1, 0, 1, 1, 0, 3, 0, 0, 1),), ' | ', 0.75, ', ', 1.5277777777777777)
('Check node: ', ((1, 0, 1, 0, 0, 3, 1, 0, 1),), ' | ', 0.75, ', ', 1.5277777777777777)
('Check node: ', ((1, 0, 1, 0, 0, 4, 0, 0, 1),), ' | ', 1.0, ', ', 1.7777777777777777)
('Check node: ', ((2, 0, 1, 0, 0, 3, 0, 0, 1),), ' | ', 1.0, ', ', 1.7777777777777777)
('Check node: ', ((1, 0, 1, 0, 0, 3, 0, 1, 1),), ' | ', 1.0, ', ', 1.7777777777777777)
('Check node: ', ((1, 0, 1, 2, 0, 3, 0, 0, 1),), ' | ', 1.5, ', ', 2.3888888888888888)
('Check node: ', ((1, 0, 1, 0, 0, 3, 2, 0, 1),), ' | ', 1.5, ', ', 2.3888888888888888)
('Check node: ', ((1, 0, 1, 1, 0, 3, 1, 0, 1),), ' | ', 1.5, ', ', 2.3888888888888888)
('Check node: ', ((1, 0, 1, 1, 0, 3, 1, 0, 1),), ' | ', 1.5, ', ', 2.3888888888888888)
('Check node: ', ((1, 0, 1, 1, 0, 4, 0, 0, 1),), ' | ', 1.75, ', ', 2.6388888888888888)
('Chec

KeyboardInterrupt: 

In [500]:
[(k,ag[k]) for k in gift_types]

[('ball', 1078),
 ('bike', 500),
 ('blocks', 0),
 ('book', 867),
 ('coal', 166),
 ('doll', 664),
 ('gloves', 89),
 ('horse', 1000),
 ('train', 999)]

In [246]:
len(total_state), total_state

(0, [])

In [269]:
score(total_state), score(total_state) * (total_n_bags) / (n_bags * counter)

(30920.446246334373, 35572.765198530171)

In [888]:
def to_submission(state, available_gifts, gift_types):
    n_gifts = [available_gifts[t] for t in gift_types]
    output = []
    for bag in state:
        o = []
        for index, count in enumerate(bag):   
            gift_type = gift_types[index]
            for i in range(count):
                v = n_gifts[index] - 1
                assert v >= 0, "Gift index is negative"
                o.append(gift_type + '_%i' % v)
                n_gifts[index] -= 1
        output.append(o)  
    return output
        
submission = to_submission(total_state, available_gifts, gift_types)
# print(submission)

In [886]:
from datetime import datetime
submission_file = '../results/submission_' + \
                  str(datetime.now().strftime("%Y-%m-%d-%H-%M")) + \
                  '.csv'

In [887]:
def write_submission(state, filename):
    with open(filename, 'w') as w:
        w.write("Gifts\n")
        for bag in state:
            w.write(' '.join(bag) + '\n')
    
write_submission(submission, submission_file)