In [1]:
import maude
import random
import numpy as np
import numpy.ma as ma
import tqdm
from tqdm.notebook import tqdm

In [2]:
# Start the Maude interpreter and load the sliding-tiles specification from
# the notebook's working directory.
maude.init()
maude.load('./sliding-tiles.maude')
# `m` is the module handle that the later cells use as a global for
# parsing terms and listing rules.
m = maude.getCurrentModule()
print('Using', m, 'module')

Using TEST module
reduce in TEST : size .
rewrites: 1 in 0ms cpu (0ms real) (~ rewrites/second)
result NzNat: 3


['up', 'right', 'down', 'left']


In [9]:
def st_generator(path=None, n_shuffles=100):
    """Build a shuffled 3x3 sliding-tiles configuration as a Maude term.

    The board starts with tile N*i+j at row i, column j, so tile 0 sits at
    (0, 0).  Tile 0 is then swapped with a randomly chosen neighbour until
    `n_shuffles` legal moves have been made; draws that would leave the
    board are discarded and do not count.

    Args:
        path: optional list; the label of every legal move ('up', 'right',
            'down', 'left') is appended to it in order.
        n_shuffles: number of legal moves to perform.

    Returns:
        The result of parsing 'tile(n,i,j) ...' with the global module `m`.
    """
    N = 3  # tiles = N x N
    board = [[N * r + c for c in range(N)] for r in range(N)]

    # Indexed by the random action code: (label, row delta, column delta).
    moves = (('up', -1, 0), ('right', 0, 1), ('down', 1, 0), ('left', 0, -1))

    zero_r, zero_c = 0, 0
    made = 0
    while made < n_shuffles:
        label, d_row, d_col = moves[random.choice(range(4))]
        new_r, new_c = zero_r + d_row, zero_c + d_col
        if not (0 <= new_r < N and 0 <= new_c < N):
            continue  # illegal move: draw again without counting it
        board[zero_r][zero_c], board[new_r][new_c] = (
            board[new_r][new_c], board[zero_r][zero_c])
        zero_r, zero_c = new_r, new_c
        made += 1
        if path is not None:
            path.append(label)

    atoms = []
    for i, row in enumerate(board):
        for j, n in enumerate(row):
            atoms.append(f'tile({n},{i},{j})')
    return m.parseTerm(' '.join(atoms))

In [47]:
class MaudeEnv():
    """Gym-like environment whose states are Maude terms.

    Actions are the labels of the labelled rules of the global module `m`.
    The observable state is a dict with:
      'term'    - the current concrete term,
      'acfg'    - pair (abstract configuration, goal flag), see `get_acfg`,
      'actions' - labels of the rules applicable to the current term.
    """

    def __init__(self, g=None, init_term=None):
        """Collect the rule labels and move to the initial state.

        Args:
            g: zero-argument callable producing a start term; used by
                `reset` whenever no explicit term is given.
            init_term: optional start term.
        """
        self.conf_gen = g
        self.rules = []
        for rl in m.getRules():
            label = rl.getLabel()
            if label is not None:
                self.rules.append(label)
        self.reset(init_term)

    def reset(self, init_term=None):
        """Set the current term and recompute the derived state.

        Args:
            init_term: term to move to; when None, `self.conf_gen()` is
                called to produce one.

        Returns:
            The state dict (see `get_state`).
        """
        # Identity test: never routes None through the term's own `==`.
        if init_term is None:
            t = self.conf_gen()
        else:
            t = init_term
        self.term = t
        self.acfg = self.get_acfg(t)
        # Use the labels collected in __init__; the original read a
        # notebook-level `rules` variable that is defined in a later cell.
        self.actions = [rl for rl in self.rules if any(True for _ in t.apply(rl))]
        return self.get_state()

    def get_state(self):
        """Return the current state as a dict."""
        return {
            'term' : self.term,
            'acfg' : self.acfg,
            'actions' : self.actions,
        }

    def step(self, action):
        """Apply the rule labelled `action` to the current term.

        When the rule has several possible results, one is picked uniformly
        at random.

        Returns:
            (state, reward, done); reward is 1.0 iff the new term is a goal,
            and done is True exactly in that case.

        Raises:
            Exception: if the rule cannot be applied to the current term.
        """
        nbrs = [next_t for next_t, _, _, _ in self.term.apply(action)]
        if nbrs == []:
            raise Exception("invalid action")
        state = self.reset(random.choice(nbrs))
        reward = 1.0 if self.is_goal() else 0.0
        done = reward == 1.0 # TODO: +done if no rewrite possible
        return state, reward, done

    def get_acfg(self, t):
        """Compute the abstract configuration of `t`.

        For every rule label and every way the rule applies to `t`, the
        rule's left-hand side is instantiated with the matching
        substitution; the instances are concatenated, parsed as one term and
        reduced.

        Returns:
            (reduced abstract term, goal flag of `t`).
        """
        term_str = ''.join(
            f'{sb.instantiate(rl.getLhs())} '
            for action in self.rules
            for _r, sb, _ctx, rl in t.apply(action)
        )
        acfg = m.parseTerm(term_str)
        acfg.reduce()
        # The flag describes `t` itself, not whatever self.term currently is.
        return (acfg, self.is_goal(t))

    def is_goal(self, t=None):
        """Return True iff `isGoal(t)` reduces to 'true' in module `m`.

        Args:
            t: term to test; defaults to the current term.
        """
        if t is None:
            t = self.term
        query = m.parseTerm(f'isGoal({t.prettyPrint(0)})')
        query.reduce()
        return query.prettyPrint(0) == 'true'

In [17]:
class QTable():
    """Sparse tabular Q-function stored as {state: {action: q}}.

    Entries that are absent are read as `q_init`; values equal to `q_init`
    are never stored, so `get_size` counts only the non-default entries.
    States and actions must be hashable.
    """

    def __init__(self):
        self.q_init = 0.0     # value of every entry that is not stored
        self.q_dict = dict()  # state -> {action -> q}

    def get_q(self, s, a):
        """Return Q(s, a), or `q_init` when the entry is not stored."""
        row = self.q_dict.get(s)
        if row is None:
            return self.q_init
        return row.get(a, self.q_init)

    def set_q(self, s, a, q):
        """Store Q(s, a) = q.

        A value equal to `q_init` is represented by absence: any existing
        entry is removed (and the row too once it becomes empty), so a stale
        value can never survive an update back to the default.
        """
        # TODO deepcopy terms
        if q == self.q_init:
            row = self.q_dict.get(s)
            if row is not None:
                row.pop(a, None)
                if not row:
                    del self.q_dict[s]
            return
        self.q_dict.setdefault(s, {})[a] = q

    def argmax_q(self, s, actions): # actions: sized iterable of action labels
        """Return the action in `actions` with the largest Q(s, .).

        Ties go to the first such action in iteration order.
        Returns -1 when `s` has no stored entry or `actions` is empty.
        """
        if s in self.q_dict and len(actions) != 0:
            row = self.q_dict[s]
            return max(actions, key=lambda a: row.get(a, self.q_init)) # FIXME: random choice if tie
        return -1

    def max_q(self, s):
        """Return the largest stored Q(s, .), or `q_init` if `s` is unknown.

        Only stored entries are considered; rows are never left empty.
        """
        row = self.q_dict.get(s)
        if row:
            return max(row.values())
        return self.q_init

    def get_size(self):
        """Return the number of stored (non-default) entries."""
        return sum(len(row) for row in self.q_dict.values())

    def print_v(self):
        """Print a Maude functional module mapping each state to max_q."""
        print('fmod SCORE is')
        for t in self.q_dict:
            print(f'  eq score({t}) = {self.max_q(t)} .')
        print(f'  eq score(X) = {self.q_init} [owise] .')
        print('endfm')

    def print_q(self):
        """Print a Maude module with one `score` equation per stored entry."""
        print('load dp.maude')
        print('mod SCORE is')
        print('  pr DP5 .')
        print('  pr FLOAT .')
        print('  op score : AConf AConf -> Float .')
        for t1, row in self.q_dict.items():
            for t2, q in row.items():
                print(f'  eq score({t1}, {t2}) = {q} .')
        print(f'  eq score(X:AConf, Y:AConf) = {self.q_init} [owise] .') # TODO: 0 should be printed 0.0
        print('endm')

    def get_q_closure(self, states, a):
        """Return max Q(s, a) over `states` (`q_init` if `states` is empty)."""
        return max((self.get_q(s, a) for s in states), default=self.q_init)

    def max_q_closure(self, states):
        """Return max of max_q(s) over `states` (`q_init` if it is empty)."""
        return max((self.max_q(s) for s in states), default=self.q_init)
    
    # def argmax_q_closure

In [18]:
def inverse_action(action):
    """Return the move that undoes `action` (None for an unknown label)."""
    opposite = {
        'up': 'down',
        'down': 'up',
        'right': 'left',
        'left': 'right',
    }
    return opposite.get(action)

def generate_episode():
    """Create a shuffled start state together with a move sequence back.

    The shuffle moves recorded by `st_generator` are reversed and each one
    is replaced by its inverse, giving a path from the shuffled term to the
    term the shuffle started from.

    Returns:
        (MaudeEnv positioned on the shuffled term, list of action labels).
    """
    shuffle_moves = []
    start_term = st_generator(shuffle_moves)
    way_back = [inverse_action(move) for move in reversed(shuffle_moves)]
    return MaudeEnv(None, start_term), way_back

In [48]:
def pre_train(n_training_episodes, max_steps, qt, lr=None, discount=None):
    """Fill a Q-table by replaying known paths from shuffled states.

    Every episode comes from `generate_episode()`, which returns an
    environment and a list of actions.  The actions are replayed one by one
    and the standard Q-learning update is applied after each step:

        Q(s,a) := Q(s,a) + lr * (R + discount * max_a' Q(s',a') - Q(s,a))

    States are the 'acfg' entries of the environment's state dict.

    Args:
        n_training_episodes: number of episodes to replay.
        max_steps: an episode is truncated after this many steps.
        qt: Q-table offering get_q / max_q / set_q; updated in place.
        lr: learning rate; defaults to the global `learning_rate`.
        discount: discount factor; defaults to the global `gamma`.

    Returns:
        (qt, stat) where stat is the sum of all rewards received.
    """
    # Fall back to the notebook-level hyper-parameters when not overridden.
    alpha = learning_rate if lr is None else lr
    discount_factor = gamma if discount is None else discount

    stat = 0
    for episode in tqdm(range(n_training_episodes)):
        env, path = generate_episode()
        state = env.get_state()

        for step_no, a in enumerate(path):
            # Truncate: max_steps was previously accepted but never honoured.
            if step_no >= max_steps:
                break

            s = state['acfg']
            new_state, reward, done = env.step(a)
            ns = new_state['acfg']
            stat += reward

            # Update Q(s,a):= Q(s,a) + lr [R(s,a) + gamma * max Q(s',a') - Q(s,a)]
            old_q = qt.get_q(s, a)
            new_q = old_q + alpha * (
                reward + discount_factor * qt.max_q(ns) - old_q
            )
            qt.set_q(s, a, new_q)

            # If terminated finish the episode
            if done:
                break

            # Our next state is the new state
            state = new_state
    print('training done!')
    return qt, stat

In [49]:
# Seed Python's RNG so the shuffles drawn by st_generator and the random
# choices made inside MaudeEnv.step are the same on every Run All.
SEED = 0
random.seed(SEED)

# Hyper-parameters; pre_train reads learning_rate and gamma as globals.
n_training_episodes = 100
max_steps = 1000
learning_rate = 0.7  # Learning rate
gamma = 0.95  # Discounting rate

# Pre-train a fresh Q-table; stat is the sum of rewards over all episodes.
preQtable = QTable()
preQtable, stat = pre_train(n_training_episodes, max_steps, preQtable)

  0%|          | 0/100 [00:00<?, ?it/s]

training done!


In [50]:
# Total reward collected during pre-training and number of stored Q entries.
print(stat, preQtable.get_size())

100.0 1511


In [51]:
# Displays the whole table: keys are (abstract term, goal flag) pairs, values
# map action labels to Q values.  NOTE(review): this prints every entry;
# consider showing only a few rows.
preQtable.q_dict

{(tile(0, 1, 0) tile(3, 0, 0) tile(4, 1, 1) tile(6, 2, 0),
  False): {'up': 0.9999801927669164, 'down': 0.6352929493126704, 'right': 0.8250804406064302},
 (tile(0, 0, 1) tile(1, 0, 0) tile(2, 0, 2) tile(4, 1, 1),
  False): {'left': 0.9932486469340529, 'right': 0.7655491614466073, 'down': 0.7898354103087897},
 (tile(0, 2, 0) tile(6, 1, 0) tile(7, 2, 1), False): {'up': 0.8559697445770813,
  'right': 0.5698846462636481},
 (tile(0, 0, 2) tile(2, 0, 1) tile(6, 1, 2),
  False): {'left': 0.36249132183774196},
 (tile(0, 0, 0) tile(1, 0, 1) tile(7, 1, 0),
  False): {'right': 0.38241444139287945, 'down': 0.21767283849640237},
 (tile(0, 1, 1) tile(3, 1, 0) tile(4, 0, 1) tile(5, 1, 2) tile(7, 2, 1),
  False): {'up': 0.8102122854781797,
  'left': 0.7397970018568745,
  'down': 0.6861438558231258,
  'right': 0.7661437208723332},
 (tile(0, 0, 2) tile(2, 0, 1) tile(5, 1, 2),
  False): {'left': 0.9007909539391603, 'down': 0.7082901168370295},
 (tile(0, 0, 1) tile(1, 0, 0) tile(2, 0, 2) tile(7, 1, 1),
  

In [63]:
# Labels of all labelled rules of the module; the helper functions below
# read this list as a global.
rules = [
    label
    for label in (rl.getLabel() for rl in m.getRules())
    if label is not None
]
print(rules)

def get_nbrs(t):
    """Return the successors of `t` as a list of (next_term, rule_label).

    Runs a depth-1 search from `t` against the pattern 'X:Conf'.  For each
    solution, the label is read from element 1 of the list returned by the
    solution's path callable (presumably the first rule applied on the way
    from `t` - confirm against the maude bindings).
    """
    pattern = m.parseTerm('X:Conf')
    successors = []
    for nxt, _subs, trace, _nrew in t.search(1, pattern, depth = 1):
        successors.append((nxt, trace()[1].getLabel()))
    return successors
#print(get_nbrs(t))

def get_acfg(t):
    """Return (abstract configuration of `t`, goal flag of `t`).

    For each label in the global `rules` and each way that rule applies to
    `t`, the rule's left-hand side is instantiated with the matching
    substitution.  The instances are joined with spaces, parsed as a single
    term in module `m` and reduced.
    """
    pieces = []
    for label in rules:
        for _result, subst, _ctx, rule in t.apply(label):
            pieces.append(f'{subst.instantiate(rule.getLhs())} ')
    abstract = m.parseTerm(''.join(pieces))
    abstract.reduce()
    return (abstract, is_goal(t))

def is_goal(t):
    """Return True iff the term `isGoal(t)` reduces to 'true' in module `m`."""
    query = m.parseTerm('isGoal(' + t.prettyPrint(0) + ')')
    query.reduce()
    verdict = query.prettyPrint(0)
    return verdict == 'true'
#print(is_goal(t))

def get_k_closure(s,k):
    """Return `s` followed by every term found by a depth-`k` search from it.

    The search is run on `s` itself against the pattern 'X:AConf', so `s`
    must be a term (an object with a `search` method).
    """
    pattern = m.parseTerm('X:AConf')
    closure = [s]
    for found, _subs, _path, _nrew in s.search(1, pattern, depth = k):
        closure.append(found)
    return closure

['up', 'right', 'down', 'left']


In [64]:
import heapq

class TermWrapper():
    """Wrap a term so that it can sit in a (priority, item) heap entry.

    When two heap entries have equal priority, tuple comparison falls
    through to the items.  Wrapped terms are deliberately unordered:
    `a < b` is always False, so only the priority decides the heap order.
    """

    def __init__(self,t):
        self.t = t  # the wrapped term

    def __lt__(self, other):
        # A comparison method should return a bool, not the int 0.
        return False

# BFS with heuristics given by abstract Q table
def search(t, mode=1, max_step=10000, qt=None):
    """Best-first search for a goal term, returning the number of expansions.

    Terms are popped from a priority queue (lowest priority first); each
    unvisited term counts as one step.  The search stops when a goal term is
    popped, when the queue is empty, or after `max_step` steps.

    Args:
        t: start term.
        mode: how successors are prioritised.
            0 - breadth-first: a successor gets the step number at which
                its parent was expanded.
            1 - Q heuristic: -Q(s, a), with s the abstract configuration of
                the parent and a the rule label leading to the successor.
            2 - like 1, but taking the best Q over the depth-5 closure of
                the parent's abstract configuration.
        max_step: upper bound on the number of expanded terms.
        qt: Q-table offering get_q / get_q_closure; required for modes 1, 2.

    Returns:
        Number of terms expanded.  A value equal to `max_step` means the
        budget ran out, not necessarily that a goal was found.

    Raises:
        ValueError: unknown `mode`, or `qt` missing for mode 1 or 2.
    """
    if mode not in (0, 1, 2):
        raise ValueError(f'unknown search mode: {mode!r}')
    if mode != 0 and qt is None:
        raise ValueError('search modes 1 and 2 require a Q-table (qt)')

    visited = set()
    i = 0
    queue = [(i,TermWrapper(t))] # (priority, concrete_state)

    while queue and i < max_step:
        t = heapq.heappop(queue)[1].t
        if t in visited:
            continue
        i += 1
        visited.add(t)
        if is_goal(t):
            break

        nbrs = get_nbrs(t)
        if mode == 0: # bfs
            p_nbrs = [(i, TermWrapper(nt)) for (nt, a) in nbrs]
        else:
            # The abstract configuration is only needed by the Q heuristics.
            s = get_acfg(t)
            if mode == 1: # qhs
                p_nbrs = [(-qt.get_q(s, a), TermWrapper(nt)) for (nt, a) in nbrs]
            else: # qhs with closure
                # Q-table keys are (abstract term, goal flag) pairs, while
                # get_k_closure works on the abstract term alone, so the
                # closure is taken over s[0] and re-paired with s[1].
                # NOTE(review): assumes the closure members share the goal
                # flag of `s` - confirm.
                closure = [(c, s[1]) for c in get_k_closure(s[0], 5)]
                p_nbrs = [(-qt.get_q_closure(closure, a), TermWrapper(nt)) for (nt, a) in nbrs]
        for item in p_nbrs:
            heapq.heappush(queue, item)
    return i

In [65]:
# Untrained Q-table: every lookup returns the initial value.
emptyQtable = QTable()

In [68]:
# evaluation of search with various modes
N = 3
stat_1 = []
stat_2 = []
i = 0
while i < N:
    print(f'test {i}')
    t = st_generator(n_shuffles=300)
    #search(t, 0,100000) # very slow!
    stat_1.append(search(t, 1, 100000, preQtable)) # qhs
    stat_2.append(search(t, 0, 100000)) # bfs
    #stat_2.append(search(t, 1, 100000, emptyQtable))
    i += 1
print('=== stat 1 ===')
print(stat_1)
print('=== stat 2 ===')
print(stat_2)

test 0
test 1
test 2
=== stat 1 ===
[81, 169, 35]
=== stat 2 ===
[26666, 100000, 10975]


In [69]:
# Mean and (population) standard deviation of the expansion counts.
s1 = np.array(stat_1)
s2 = np.array(stat_2)
for label, sample in (('stat1', s1), ('stat2', s2)):
    print(f'{label} mean:', np.mean(sample))
    print(f'{label} std:', np.std(sample))

stat1 mean: 95.0
stat1 std: 55.59376463837169
stat2 mean: 45880.333333333336
stat2 std: 38800.821385922
