In [35]:
import random
import numpy as np
import collections
from game import Board, Game
from player_random import PlayerRandom
import sys
import uuid
import inspect
import copy
import anytree
from anytree import Node as anytreeNode
from anytree import RenderTree, PreOrderIter
from anytree.exporter import DotExporter
from enum import Enum
import functools
import tqdm
from tqdm.notebook import tqdm as tqdm_nb
import itertools
from vanilla_uct_player import VanillaUCTPlayer
from pprint import pprint
from bokeh.io import output_file, show, output_notebook, reset_output, export_png
from bokeh.models import ColumnDataSource
from bokeh.plotting import figure
import pickle

In [3]:
def get_node_id():
    """Return a short random identifier: the first 8 hex digits of a UUID4."""
    return uuid.uuid4().hex[:8]

In [4]:
class Node(anytreeNode):
    """anytree node that carries a short random ``_id`` and compares by rendering.

    Two nodes are equal when ``RenderTree(...).by_attr()`` produces the same
    text for the subtrees rooted at them; the hash is derived from that same
    text, so equal subtrees land in the same dict/Counter bucket.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._id = get_node_id()

    def __eq__(self, other):
        own_rendering = RenderTree(self).by_attr()
        other_rendering = RenderTree(other).by_attr()
        return own_rendering == other_rendering

    def __hash__(self):
        return hash(str(RenderTree(self).by_attr()))

In [5]:
class Script(object):
    """Base class for generated player scripts.

    Holds the syntax tree the script was rendered from. Concrete scripts
    (rendered from ``script_template``) override ``get_action(self, state)``.
    """

    def __init__(self, tree):
        super().__init__()
        self.tree = tree

    def get_action(self, state=None):
        """Choose an action for ``state``; must be overridden by subclasses.

        ``state`` is accepted (and optional) so the base signature matches
        how callers and generated subclasses use it; calling the base method
        with a state now raises NotImplementedError rather than TypeError.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError
        
class Lib(object):
    """Boolean predicates over a game state and/or a candidate action.

    The functions take no ``self`` and are called through the class
    (``Lib.is_doubles(a)``). Their names and parameter names are inspected
    elsewhere in this file to build grammar rules, so both are part of the
    contract.
    """

    def is_doubles(action):
        """True when ``action`` has at least two items and the first two are equal."""
        return bool(len(action) > 1 and action[0] == action[1])

    def contains_number(action, column_num):
        """True when ``action`` is not a string and contains ``column_num``."""
        if isinstance(action, str):
            return False
        return column_num in action

    def has_won_column(state, action):
        """True when ``state`` reports at least one column won this round.

        ``action`` is accepted but not used.
        """
        won_columns = state.columns_won_current_round()
        return len(won_columns) > 0

    def column_progression_this_round_greater_than(state, column_num, small_num):
        """True when this round's conquered positions in the column are >= ``small_num``."""
        conquered = state.number_positions_conquered_this_round(column_num)
        return conquered >= small_num

    def column_progression_greater_than(state, column_num, small_num):
        """True when the conquered positions in the column are >= ``small_num``."""
        conquered = state.number_positions_conquered(column_num)
        return conquered >= small_num

    def progression_this_round_greater_than(state, small_num):
        """True when this round's conquered positions summed over columns 2..6 are >= ``small_num``."""
        total = sum(map(state.number_positions_conquered_this_round, range(2, 7)))
        return total >= small_num

    def progression_greater_than(state, small_num):
        """True when the conquered positions summed over columns 2..6 are >= ``small_num``."""
        total = sum(map(state.number_positions_conquered, range(2, 7)))
        return total >= small_num

    def is_yes_action(action):
        """True only for the string action ``'y'``."""
        return isinstance(action, str) and action == 'y'

    def is_no_action(action):
        """True only for the string action ``'n'``."""
        return isinstance(action, str) and action == 'n'

In [41]:
# Names of the structural grammar symbols. They come first in the Rule enum,
# so their order fixes the enum values.
default_enums = [
    "START", "IF_BLOCK", "IF_BODY",
    "BOOL_EXP", "AND_EXP", "OR_EXP", "NOT_EXP", "BOOL",
    "RETURN", "FUNC_CALL",
    "COLUMN_NUM", "SMALL_NUM",
]

# (name, function) pairs for every plain function found on Lib.
lib_functions = inspect.getmembers(Lib, predicate=inspect.isfunction)

In [42]:
# Upper-cased names of the Lib functions, skipping underscore-prefixed ones.
lib_func_names = [
    func_name.upper()
    for func_name, _func in lib_functions
    if not func_name.startswith('_')
]

# One enum member per structural symbol, followed by one per Lib function.
Rule = Enum('Rule', [*default_enums, *lib_func_names])

In [43]:
def get_func(name):
    """Return the Lib function whose name matches ``name`` case-insensitively.

    Raises:
        KeyError: if no collected Lib function has that (lower-cased) name.
    """
    registry = dict(lib_functions)
    wanted = name.lower()
    return registry[wanted]

def get_params(f):
    """Return the parameter names of callable ``f`` as a list, in declaration order."""
    signature = inspect.signature(f)
    return [param_name for param_name in signature.parameters]

In [44]:
def make_dynamic_rule(name):
    """Build the grammar production for calling the Lib function ``name``.

    The result is a list whose first item is the call target
    (``'Lib.<lowercase name>'``) followed by one symbol per parameter of the
    function: literal source text for ``state``/``action``/``self`` and a
    Rule non-terminal for ``column_num``/``small_num``.

    Raises:
        ValueError: if the function has a parameter with any other name.
    """
    func = get_func(name)
    symbol_for = {
        'state': 'state',
        'action': 'a',  # generated scripts bind the candidate action to `a`
        'column_num': Rule.COLUMN_NUM,
        'small_num': Rule.SMALL_NUM,
        'self': 'self',
    }
    production = ['Lib.' + name.lower()]
    for param in get_params(func):
        if param not in symbol_for:
            raise ValueError
        production.append(symbol_for[param])
    return production

In [45]:
class Sampler:
    """Marker base class for grammar alternatives that are sampled at expansion time."""

class Diminishing(Sampler):
    """Sampler that repeats one rule with geometrically shrinking probability."""

    def __init__(self, gamma, rule):
        self.rule = rule
        self.gamma = gamma

    def sample(self):
        """Return a list holding ``rule`` one or more times.

        The continuation threshold starts at 1 and is multiplied by ``gamma``
        after every repetition; sampling stops at the first random draw that
        exceeds the threshold. With ``gamma >= 1`` the threshold never drops
        below 1, so the loop would not terminate.
        """
        repeated = []
        threshold = 1
        while True:
            if random.random() > threshold:
                break
            repeated.append(self.rule)
            threshold *= self.gamma
        return repeated
    
class Weighted(Sampler):
    """Sampler that picks one rule with probability proportional to its weight.

    ``dict_`` maps weight -> rule (the weight is the key, so two rules cannot
    share the same weight).
    """

    def __init__(self, dict_):
        self.dict = dict_

    def sample(self):
        """Draw one rule and return it as a list (wrapping it if it is not one).

        Raises:
            ValueError: if the draw is not covered by any weight's share.
        """
        scale = 1 / sum(self.dict.keys())
        remaining = random.random()
        for weight, rule in self.dict.items():
            share = weight * scale
            if remaining > share:
                # Not in this rule's slice: move on to the next one.
                remaining -= share
                continue
            return rule if isinstance(rule, list) else [rule]
        raise ValueError

In [46]:
# Production table used by generate_tree: each key is a Rule and each value is
# a sequence of alternatives, one of which is picked with random.choice.
# An alternative that is a list becomes several child nodes; a single Rule or
# string becomes one child node. Strings are terminals that render() emits
# verbatim.
grammar = {
    Rule.START: [
        Rule.IF_BLOCK,
    ],
    # The single alternative is a list: condition first, then body.
    Rule.IF_BLOCK: (
        [Rule.BOOL_EXP, Rule.IF_BODY],
    ),
    Rule.IF_BODY: [
        Rule.RETURN,
    ],
    Rule.BOOL_EXP: [
#         Rule.BOOL_EXP,
        Rule.BOOL, 
        Rule.AND_EXP,
        Rule.NOT_EXP,
        Rule.OR_EXP,
    ],
    # Both operands are plain BOOLs; the weighted, recursive variant is
    # commented out.
    Rule.AND_EXP: [
#         Weighted({
#             7: [Rule.BOOL, Rule.BOOL],
#             3: [Rule.BOOL_EXP, Rule.BOOL_EXP]
#         })
        [Rule.BOOL, Rule.BOOL],
    ],
    Rule.OR_EXP: [
#         Weighted({
#             7: [Rule.BOOL, Rule.BOOL],
#             3: [Rule.BOOL_EXP, Rule.BOOL_EXP]
#         })
        [Rule.BOOL, Rule.BOOL],
    ],
    Rule.NOT_EXP: [
        Rule.BOOL,
    ],
    Rule.BOOL: (
        Rule.FUNC_CALL,
    ),
    # One alternative per Lib function: [call target, argument symbols...].
    Rule.FUNC_CALL: [
        make_dynamic_rule(name) for name in lib_func_names
    ],
    Rule.RETURN: (
        "return a",
    ),
    Rule.COLUMN_NUM: [
        '2', '3', '4', '5', '6'
    ],
    Rule.SMALL_NUM: (
        '0', '1', '2', '3'
    )
}
grammar

{<Rule.START: 1>: [<Rule.IF_BLOCK: 2>],
 <Rule.IF_BLOCK: 2>: ([<Rule.BOOL_EXP: 4>, <Rule.IF_BODY: 3>],),
 <Rule.IF_BODY: 3>: [<Rule.RETURN: 9>],
 <Rule.BOOL_EXP: 4>: [<Rule.BOOL: 8>,
  <Rule.AND_EXP: 5>,
  <Rule.NOT_EXP: 7>,
  <Rule.OR_EXP: 6>],
 <Rule.AND_EXP: 5>: [[<Rule.BOOL: 8>, <Rule.BOOL: 8>]],
 <Rule.OR_EXP: 6>: [[<Rule.BOOL: 8>, <Rule.BOOL: 8>]],
 <Rule.NOT_EXP: 7>: [<Rule.BOOL: 8>],
 <Rule.BOOL: 8>: (<Rule.FUNC_CALL: 10>,),
 <Rule.FUNC_CALL: 10>: [['Lib.column_progression_greater_than',
   'state',
   <Rule.COLUMN_NUM: 11>,
   <Rule.SMALL_NUM: 12>],
  ['Lib.column_progression_this_round_greater_than',
   'state',
   <Rule.COLUMN_NUM: 11>,
   <Rule.SMALL_NUM: 12>],
  ['Lib.contains_number', 'a', <Rule.COLUMN_NUM: 11>],
  ['Lib.has_won_column', 'state', 'a'],
  ['Lib.is_doubles', 'a'],
  ['Lib.is_no_action', 'a'],
  ['Lib.is_yes_action', 'a'],
  ['Lib.progression_greater_than', 'state', <Rule.SMALL_NUM: 12>],
  ['Lib.progression_this_round_greater_than', 'state', <Rule.SMALL_NUM

In [47]:
def generate_tree(root):
    """Randomly expand ``root`` in place according to ``grammar`` and return it.

    String-named nodes and nodes whose name has no (non-empty) entry in the
    grammar are left as leaves. Otherwise one alternative is chosen with
    ``random.choice`` and each symbol in it becomes a child node that is
    expanded recursively.

    Raises:
        ValueError: if the chosen alternative is not a list, Rule, str,
            Sampler or None.
    """
    symbol = root.name
    if isinstance(symbol, str):
        return root
    alternatives = grammar.get(symbol, None)
    if not alternatives:
        return root

    chosen = random.choice(alternatives)

    if isinstance(chosen, list):
        expansion = chosen
    elif isinstance(chosen, (Rule, str)):
        expansion = [chosen]
    elif isinstance(chosen, Sampler):
        expansion = list(chosen.sample())
    elif chosen is None:
        expansion = []
    else:
        raise ValueError

    for child_symbol in expansion:
        generate_tree(Node(child_symbol, parent=root))
    return root

In [48]:
def get_random_tree(seed=None):
    """Generate a random syntax tree rooted at ``Rule.START``.

    Args:
        seed: when not None, the global ``random`` generator is seeded with
            it first. The check is ``is not None`` so that ``seed=0`` (or
            any other falsy seed) is honoured too.

    Returns:
        The root node of the generated tree.
    """
    if seed is not None:
        random.seed(seed)
    root = Node(Rule.START)
    return generate_tree(root)

def print_tree(tree):
    """Print ``tree`` rendered with anytree's ASCII style."""
    rendering = RenderTree(tree, style=anytree.render.AsciiStyle())
    print(rendering.by_attr())

In [49]:
def indent(raw, level):
    """Prefix every line of ``raw`` with ``level`` four-space tabs.

    Lines are split with ``str.splitlines`` and re-joined with ``'\\n'``, so
    a trailing newline in ``raw`` is not preserved.
    """
    prefix = '    ' * level
    return '\n'.join(prefix + line for line in raw.splitlines())

In [53]:
def render(node):
    """Render the subtree rooted at ``node`` as Python source text.

    String-named nodes render as their name. Rule-named nodes use a template
    for if-blocks, and/or/not expressions and function calls; any other Rule
    renders as the concatenation of its rendered children. A node whose name
    is neither a string nor a Rule yields None.
    """
    label = node.name
    if isinstance(label, str):
        return label
    if not isinstance(label, Rule):
        return None

    kids = node.children

    if label == Rule.IF_BLOCK:
        condition = render(kids[0])
        body = indent(render(kids[1]), 1)
        return "if ({0}):\n{1}\n".format(condition, body)

    binary_templates = {
        Rule.AND_EXP: "({0} and {1})",
        Rule.OR_EXP: "({0} or {1})",
    }
    if label in binary_templates:
        left = render(kids[0])
        right = render(kids[1])
        return binary_templates[label].format(left, right)

    if label == Rule.NOT_EXP:
        return "not ({0})".format(render(kids[0]))

    if label == Rule.FUNC_CALL:
        # First child is the call target, the rest are its arguments.
        target = render(kids[0])
        arguments = ', '.join([render(arg) for arg in kids[1:]])
        return "{0}({1})".format(target, arguments)

    return ''.join(render(kid) for kid in kids)

In [54]:
# Source template for a generated Script subclass: {0} is the class name and
# {1} is the already-indented rule code placed inside the `for a in actions`
# loop. If no rule returns, the first available action is returned.
script_template = r"""

class {0}(Script):
    def get_action(self, state):
        actions = state.available_moves()
        for a in actions:
{1}
        return actions[0]
"""

def render_script(node):
    """Render ``node`` into the source of a uniquely named Script subclass.

    Returns:
        A ``(class_name, source_code)`` tuple; the class name is ``Script_``
        followed by the 32 hex digits of a fresh UUID4.
    """
    script_name = 'Script_' + uuid.uuid4().hex
    # Three levels of indentation put the code inside the template's for-loop.
    body = indent(render(node), 3)
    source = script_template.format(script_name, body)
    return script_name, source

In [55]:
def exec_tree(tree):
    """Compile ``tree`` into a Script subclass and return an instance of it.

    The rendered class is executed into an explicit local namespace and then
    looked up there. Relying on ``exec`` writing into the function's implicit
    locals and a later ``eval`` seeing it is not guaranteed (on Python 3.13+
    each ``locals()`` call is an independent snapshot).

    Raises:
        Exception: whatever executing the rendered source raises; the error
            and the offending source are printed first, then the original
            exception is re-raised instead of being masked by a later
            NameError.
    """
    script_name, raw_script = render_script(tree)
    namespace = {}
    try:
        # Module globals stay the globals, so Script and Lib resolve as before.
        exec(raw_script, globals(), namespace)
    except Exception as e:
        print(e)
        print(raw_script)
        raise
    script_cls = namespace[script_name]
    return script_cls(tree)

In [56]:
def play_game(lhs, rhs, max_moves=200):
    """Play one two-player game between ``lhs`` (player 1) and ``rhs`` (player 2).

    Args:
        lhs: object with ``get_action(game)``; asked when ``game.player_turn == 1``.
        rhs: object with ``get_action(game)``; asked otherwise.
        max_moves: number of played moves after which an unfinished game is
            stopped and declared a draw (defaults to the previous fixed 200).

    Returns:
        The winner reported by ``game.is_finished()``, or -1 if the game was
        still unfinished after ``max_moves`` moves.
    """
    game = Game(n_players=2, dice_number=4, dice_value=3, column_range=[2, 6],
                offset=2, initial_height=1)

    is_over = False
    who_won = None
    number_of_moves = 0

    while not is_over:
        moves = game.available_moves()
        if game.is_player_busted(moves):
            # NOTE(review): presumably is_player_busted itself passes the turn
            # and re-rolls; nothing else advances the game here - confirm.
            continue

        # The game object decides whose turn it is.
        if game.player_turn == 1:
            chosen_play = lhs.get_action(game)
        else:
            chosen_play = rhs.get_action(game)
        game.play(chosen_play)
        number_of_moves += 1

        who_won, is_over = game.is_finished()

        # Only an unfinished game is cut off; a win on the very last allowed
        # move must not be overwritten with a draw.
        if not is_over and number_of_moves >= max_moves:
            is_over = True
            who_won = -1
    return who_won

def evaluate_pair(lhs, rhs, num_games=3):
    """Play ``num_games`` games and update the ``fitness`` of both players in place.

    Per game the winner (if any, and if truthy) gains 1 and each loser loses
    1; when neither player 1 nor player 2 is reported as winner, both lose 1.
    """
    for _ in range(num_games):
        outcome = play_game(lhs, rhs)
        if outcome == 1:
            champion, defeated = lhs, (rhs,)
        elif outcome == 2:
            champion, defeated = rhs, (lhs,)
        else:
            champion, defeated = None, (lhs, rhs)

        if champion:
            champion.fitness += 1
        for player in defeated:
            player.fitness -= 1

In [57]:
class RecordingUCTPlayer(VanillaUCTPlayer):
    """UCT player that logs every decision as a ``(state clone, action)`` pair."""

    def __init__(self, c, sim):
        super().__init__(c, sim)
        # Decisions in the order they were made.
        self.records = []

    def get_action(self, state):
        """Delegate to the parent player, record the decision, and return it."""
        chosen = super().get_action(state)
        snapshot = state.clone()
        self.records.append((snapshot, chosen))
        return chosen

In [58]:
def generate_uct_samples(num_games=1, c=1, simulations=1):
    """Collect ``(state, action)`` records from UCT-vs-UCT self-play.

    For each of ``num_games`` games two fresh RecordingUCTPlayer instances
    play each other; the first player's records are appended before the
    second's.
    """
    samples = []
    for _ in range(num_games):
        first = RecordingUCTPlayer(c, simulations)
        second = RecordingUCTPlayer(c, simulations)
        play_game(first, second)
        samples += first.records
        samples += second.records
    return samples

In [59]:
def generate_if_block():
    """Return a freshly generated random subtree rooted at ``Rule.IF_BLOCK``."""
    return generate_tree(Node(Rule.IF_BLOCK))

In [60]:
def is_useful_node(node):
    """Tell whether ``node`` is a candidate for regeneration.

    A candidate is a Rule-named node that has children and is neither an
    IF_BODY nor a RETURN.
    """
    label = node.name
    if isinstance(label, Rule) and node.children:
        return label not in (Rule.IF_BODY, Rule.RETURN)
    return False

In [61]:
def propose_new_script(script):
    """Return a new script obtained by regenerating one random subtree.

    The tree of ``script`` is deep-copied, one node accepted by
    ``is_useful_node`` is picked at random, its children are discarded and
    regenerated from the grammar, and the mutated tree is compiled.
    """
    mutated = copy.deepcopy(script.tree)
    candidates = anytree.search.findall(mutated, filter_=is_useful_node)
    target = random.choice(candidates)
    target.children = []
    generate_tree(target)
    return exec_tree(mutated)

In [62]:
def count_errors(script, uct_samples):
    """Count the samples on which ``script`` picks a different action than recorded.

    ``uct_samples`` is an iterable of ``(state, uct_action)`` pairs.
    """
    return sum(
        1
        for state, uct_action in uct_samples
        if uct_action != script.get_action(state)
    )

In [63]:
def get_score(script, uct_samples):
    """Score ``script`` as ``exp(-beta * errors)`` with ``beta = 0.5``.

    ``errors`` is the number of samples on which the script disagrees with
    the recorded action, so a perfect script scores 1.
    """
    beta = 0.5
    return np.exp(-beta * count_errors(script, uct_samples))

In [64]:
def run_metropolis_hastings(uct_samples, prev_scripts=None, iterations=1000):
    """Run a Metropolis-Hastings search over scripts and return the whole chain.

    Args:
        uct_samples: ``(state, action)`` pairs the scripts are scored against.
        prev_scripts: optional list of earlier scripts; when non-empty, every
            candidate is scored after being merged behind them.
        iterations: number of proposals to make.

    Returns:
        A list of ``iterations + 1`` scripts: the random start followed by
        the current script after each step (repeated when a proposal is
        rejected).
    """
    def _score(candidate):
        # Score the candidate on its own, or merged behind the earlier scripts.
        if prev_scripts:
            candidate = merge_scripts(prev_scripts + [candidate])
        return get_score(candidate, uct_samples)

    current = exec_tree(get_random_tree())
    # The current script's score only changes when a proposal is accepted, so
    # it is computed once and carried along instead of being re-merged and
    # re-scored on every iteration.
    # NOTE(review): assumes scoring the same script on the same samples is
    # deterministic - confirm for the game state implementation.
    current_score = _score(current)
    scripts = [current]
    for _ in range(iterations):
        proposal = propose_new_script(current)
        proposal_score = _score(proposal)

        if current_score == 0:
            # The current score underflowed to 0; the ratio would be inf or
            # nan, both of which previously resulted in unconditional accept.
            accept = 1
        else:
            accept = min(1, proposal_score / current_score)

        if random.random() < accept:
            current, current_score = proposal, proposal_score
        scripts.append(current)
    return scripts

In [65]:
def filter_uct_samples(script, uct_samples):
    """Return the ``(state, uct_action)`` samples that ``script`` gets wrong.

    A sample is kept when the script's action for the state differs from the
    recorded one; order is preserved.
    """
    return [
        (state, uct_action)
        for state, uct_action in uct_samples
        if uct_action != script.get_action(state)
    ]

In [66]:
def merge_scripts(scripts):
    """Combine several scripts into one by concatenating their top-level clauses.

    The children of every script's tree root are deep-copied, in script
    order, under a new ``Rule.START`` root, which is then compiled.
    """
    clauses = [
        clause
        for script in scripts
        for clause in copy.deepcopy(script.tree.children)
    ]
    merged_root = Node(Rule.START, children=clauses)
    return exec_tree(merged_root)

In [67]:
def run_synthesis(uct_samples, iterations=10, filter_sample=True, merge=True,
                  num_rounds=5):
    """Synthesize a list of scripts by repeated Metropolis-Hastings searches.

    Args:
        uct_samples: ``(state, action)`` pairs to imitate.
        iterations: proposals per Metropolis-Hastings run.
        filter_sample: when true, after each round keep only the samples the
            round's best script still gets wrong.
        merge: when true, each run scores candidates merged behind the
            scripts found in earlier rounds.
        num_rounds: number of search rounds (defaults to the previous fixed 5).

    Returns:
        The best script of each round, in round order.
    """
    all_scripts = []
    for _ in range(num_rounds):
        prev_scripts = all_scripts if merge else []
        scripts = run_metropolis_hastings(uct_samples, prev_scripts=prev_scripts,
                                          iterations=iterations)
        # min() returns the first script with the fewest errors, the same
        # element a stable sort would put first, without sorting the chain.
        best_script = min(scripts, key=lambda s: count_errors(s, uct_samples))
        if filter_sample:
            uct_samples = filter_uct_samples(best_script, uct_samples)
        all_scripts.append(best_script)
    return all_scripts

In [68]:
def plot_common_clauses(scripts, fname=None):
    """Build a horizontal bar chart of the most frequent top-level clauses.

    Clauses (children of each script's tree root) are counted across all
    scripts; the ten most common are plotted, labelled by the rendered first
    child of the clause. When ``fname`` is given the figure is also exported
    as a PNG. Returns the bokeh figure.
    """
    counter = collections.Counter(
        clause for script in scripts for clause in script.tree.children
    )
    reset_output()

    # Ascending by count.
    top = sorted(counter.most_common(10), key=lambda pair: pair[1])
    xs = [render(clause.children[0]) for clause, _ in top]
    ys = [count for _, count in top]

    p = figure(plot_width=1200, plot_height=800, y_range=xs)
    p.hbar(y=xs, right=ys, height=0.5)
    p.x_range.start = 0

    if fname:
        export_png(p, filename=fname)
    return p

In [69]:
def get_uct_samples(num_samples):
    """Load ``samples.pkl`` and return ``num_samples`` entries drawn without replacement.

    The file is unpickled as-is, so it must come from a trusted source.
    """
    with open('samples.pkl', 'rb') as handle:
        pool = pickle.load(handle)
    return random.sample(pool, k=num_samples)

In [4]:
def print_errors(multiple_run_scripts):
    """Print the error rate of each run and the overall average.

    Args:
        multiple_run_scripts: iterable of ``(uct_samples, scripts)`` pairs.
            The scripts of a run are merged before scoring; if merging fails
            the run's first script is scored instead.

    The average is total errors over total samples across all runs, so it
    stays correct when runs have different sample counts; nothing is printed
    for the average when there are no samples at all.
    """
    total_errors = 0
    total_samples = 0
    for uct_samples, scripts in multiple_run_scripts:
        try:
            full_script = merge_scripts(scripts)
        except Exception:
            # Best-effort fallback, but let KeyboardInterrupt/SystemExit through.
            full_script = scripts[0]
        num_errors = count_errors(full_script, uct_samples)
        total_errors += num_errors
        total_samples += len(uct_samples)
        print('error rate: ',  num_errors / len(uct_samples) * 100, '%')
    if total_samples:
        print('average error rate:', total_errors / total_samples * 100, '%')

In [5]:
def play(lhs, rhs, num_games=1000):
    """Return the win rate of ``lhs`` against ``rhs``, rounded to two decimals.

    Each of the ``num_games`` rounds consists of two games, one with ``lhs``
    as player 1 and one with ``lhs`` as player 2; the rate is wins over
    ``2 * num_games``.
    """
    wins = 0
    for _ in range(num_games):
        as_first = play_game(lhs, rhs)
        wins += 1 if as_first == 1 else 0
        as_second = play_game(rhs, lhs)
        wins += 1 if as_second == 2 else 0
    total_games = num_games * 2
    return round(wins / total_games, 2)

In [6]:
def print_winrates(multiple_run_scripts):
    """Print each run's win rate against a random player, then the average.

    For every ``(_, scripts)`` pair the scripts are merged into one script
    that plays against a fresh ``PlayerRandom``.
    """
    win_rates = []
    for _samples, scripts in multiple_run_scripts:
        merged = merge_scripts(scripts)
        rate = play(merged, PlayerRandom())
        win_rates.append(rate)
        print('win rate:', round(rate * 100, 2), '%')
    mean_rate = sum(win_rates) / len(win_rates)
    print('average win rate:', round(mean_rate * 100, 2), '%')