In [1]:
import maude
import random
import numpy as np
import numpy.ma as ma
import tqdm
from tqdm.notebook import tqdm

In [2]:
# Initialise the Maude runtime, load the sliding-tiles specification and keep a
# handle on the module that is current after loading. Every helper below parses
# and rewrites terms through this global `m`.
maude.init()
maude.load('./sliding-tiles.maude')
m = maude.getCurrentModule()
print('Using', m, 'module')

Using TEST module
reduce in TEST : size .
rewrites: 1 in 0ms cpu (0ms real) (~ rewrites/second)
result NzNat: 3


In [24]:
# Labels of the labelled rewrite rules of module `m`; rules without a label are
# skipped. These labels are what the rest of the notebook uses as action names.
rules = [label for label in (rl.getLabel() for rl in m.getRules()) if label is not None]
print(rules)

['up', 'right', 'down', 'left']


In [25]:
def st_generator(path=None, n_shuffles=100, size=3):
    """Build a randomly shuffled sliding-tile board and parse it as a Maude term.

    The board starts as the row-major grid ``0 .. size*size - 1`` with tile ``0``
    in the top-left cell. Random moves of tile ``0`` are drawn until
    ``n_shuffles`` of them have been applied; a draw that would push tile ``0``
    off the board is discarded and does not count.

    Args:
        path: optional list. The name of every applied move ('up', 'right',
            'down' or 'left') is appended to it, in order.
        n_shuffles: number of moves that must actually be applied.
        size: side length of the board (default 3, the value used so far).
            NOTE(review): the loaded Maude module has its own `size` constant;
            presumably the two must agree -- confirm before using other values.

    Returns:
        The result of ``m.parseTerm`` on ``'tile(n,i,j) ...'`` listing every
        cell of the shuffled board in row-major order.

    Raises:
        ValueError: if ``size`` is smaller than 1, or if moves are requested on
            a board with a single cell (no move is ever possible there, so the
            shuffle loop would never terminate).
    """
    if size < 1:
        raise ValueError('size must be at least 1')
    if size < 2 and n_shuffles > 0:
        raise ValueError('cannot shuffle a board with a single cell')

    # Indexed by the drawn action number: (move name, row delta, column delta).
    moves = (('up', -1, 0), ('right', 0, 1), ('down', 1, 0), ('left', 0, -1))

    board = [[size * i + j for j in range(size)] for i in range(size)]
    i0, j0 = 0, 0  # current position of tile 0
    applied = 0
    while applied < n_shuffles:
        name, di, dj = moves[random.choice(range(4))]
        ni, nj = i0 + di, j0 + dj
        if not (0 <= ni < size and 0 <= nj < size):
            continue  # move would leave the board: draw again
        board[i0][j0], board[ni][nj] = board[ni][nj], board[i0][j0]
        i0, j0 = ni, nj
        applied += 1
        if path is not None:
            path.append(name)

    cells = [f'tile({n},{i},{j})' for (i, row) in enumerate(board) for (j, n) in enumerate(row)]
    return m.parseTerm(' '.join(cells))

In [26]:
# Small sample configuration (two applied moves) for trying out the helpers
# defined in this cell.
t = st_generator(n_shuffles=2)
#print(t)

def get_nbrs(t):
    """List the terms a depth-1 search finds from ``t``, with the rule used.

    Returns a list of ``(next_t, a)`` pairs, where ``a`` is the label of the
    object at position 1 of the search path that leads to ``next_t``.
    """
    pattern = m.parseTerm('X:Conf')
    successors = []
    for next_t, _subs, trace, _nrew in t.search(1, pattern, depth = 1):
        successors.append((next_t, trace()[1].getLabel()))
    return successors
#print(get_nbrs(t))

def get_acfg(t):
    """Return the reduced term ``alpha(<t>)``, the abstract configuration of ``t``."""
    # TODO change aconf to acfg
    abstracted = m.parseTerm(f'alpha({t.prettyPrint(0)})')
    abstracted.reduce()
    return abstracted
#print(get_acfg(t))

def is_goal(t):
    """Return True when ``isGoal(<t>)`` reduces to the term printed as ``true``."""
    verdict = m.parseTerm(f'isGoal({t.prettyPrint(0)})')
    verdict.reduce()
    return verdict.prettyPrint(0) == 'true'
#print(is_goal(t))

def get_k_closure(s,k):
    """Return ``s`` followed by every term a search from ``s`` bounded by depth ``k`` yields."""
    pattern = m.parseTerm('X:AConf')
    closure = [s]
    for found, _subs, _trace, _nrew in s.search(1, pattern, depth = k):
        closure.append(found)
    return closure

In [6]:
#t = st_generator(n_shuffles=100)
# Fixed board: show its abstraction and the closures of that abstraction for k = 1, 2.
t = m.parseTerm('tile(6, 0, 0) tile(2, 0, 1) tile(8, 0, 2) tile(4, 1, 0) tile(5, 1, 1) tile(1, 1, 2) tile(7, 2, 0) tile(3, 2, 1) tile(0, 2, 2)')
print(t)
s = get_acfg(t)
print(s)
for depth in (1, 2):
    print(f'k={depth}')
    for ns in get_k_closure(s, depth):
        print(ns)

#init_term = m.parseTerm('tile(1, 0, 0) tile(0, 0, 1) tile(2, 0, 2) tile(3, 1, 0) tile(4, 1, 1) tile(5, 1, 2) tile(6, 2, 0) tile(7, 2, 1) tile(8, 2, 2)')
#env = MaudeEnv(st_generator,init_term)
#print(env.actions)
#env.step('left')

tile(6, 0, 0) tile(2, 0, 1) tile(8, 0, 2) tile(4, 1, 0) tile(5, 1, 1) tile(1, 1, 2) tile(7, 2, 0) tile(3, 2, 1) tile(0, 2, 2)
tile(zero, 2, 2) tile(right, 0, 1) tile(right, 1, 0) tile(right, 1, 1) tile(right, 2, 0) tile(down, 0, 0) tile(down, 0, 2) tile(upLeft, 1, 2) tile(upLeft, 2, 1)
k=1
tile(zero, 2, 2) tile(right, 0, 1) tile(right, 1, 0) tile(right, 1, 1) tile(right, 2, 0) tile(down, 0, 0) tile(down, 0, 2) tile(upLeft, 1, 2) tile(upLeft, 2, 1)
tile(zero, 2, 2) tile(up, 1, 2) tile(right, 0, 1) tile(right, 1, 0) tile(right, 1, 1) tile(right, 2, 0) tile(down, 0, 0) tile(down, 0, 2) tile(upLeft, 2, 1)
tile(zero, 2, 2) tile(right, 0, 1) tile(right, 1, 0) tile(right, 1, 1) tile(right, 2, 0) tile(down, 0, 0) tile(down, 0, 2) tile(left, 1, 2) tile(upLeft, 2, 1)
tile(zero, 2, 2) tile(up, 2, 1) tile(right, 0, 1) tile(right, 1, 0) tile(right, 1, 1) tile(right, 2, 0) tile(down, 0, 0) tile(down, 0, 2) tile(upLeft, 1, 2)
tile(zero, 2, 2) tile(right, 0, 1) tile(right, 1, 0) tile(right, 1, 1) tile

In [161]:
def _is_enabled(term, label):
    """True when applying the rule labelled `label` to `term` yields at least one result."""
    return any(True for _ in term.apply(label))

# Shuffle a board, then report rule by rule whether it applies to it.
path = []
t = st_generator(path)
print(t)
v = m.parseTerm('X:Conf')
for rl in rules:
    print(f'{rl} : {_is_enabled(t, rl)}')
# Cell output: the labels of the applicable rules only.
list(filter(lambda label: _is_enabled(t, label), rules))
#t.apply('up')
#t.search(1,v,depth=1)
#for r, sb, ctx, rl in t.apply('up'):
#print(r, 'with', rl.getLabel()) #, 'in context', ctx(v))

tile(5, 0, 0) tile(3, 0, 1) tile(7, 0, 2) tile(1, 1, 0) tile(8, 1, 1) tile(6, 1, 2) tile(4, 2, 0) tile(2, 2, 1) tile(0, 2, 2)
up : True
right : False
down : False
left : True


['up', 'left']

In [23]:
# For a fixed board, collect the instantiated left-hand side of every rewrite
# that each rule offers, and print them as one space-separated string.
t = m.parseTerm('tile(1, 0, 0) tile(6, 0, 1) tile(2, 0, 2) tile(8, 1, 0) tile(0, 1, 1) tile(7, 1, 2) tile(3, 2, 0) tile(5, 2, 1) tile(4, 2, 2)')
v = m.parseTerm('C:Conf')
print('term:', t)

matched_lhs = []
for action in rules:
    for r, sb, ctx, rl in t.apply(action):
        print(f'-------{action}--------')
        matched_lhs.append(str(sb.instantiate(rl.getLhs())))
        #print(r, 'with', rl, 'in context', ctx(v))

# Every piece is followed by one space, including the last one.
term_str = ''.join(piece + ' ' for piece in matched_lhs)
print(term_str)

term: tile(1, 0, 0) tile(6, 0, 1) tile(2, 0, 2) tile(8, 1, 0) tile(0, 1, 1) tile(7, 1, 2) tile(3, 2, 0) tile(5, 2, 1) tile(4, 2, 2)
-------up--------
-------right--------
-------down--------
-------left--------
tile(0, 1, 1) tile(6, 0, 1) tile(0, 1, 1) tile(7, 1, 2) tile(0, 1, 1) tile(5, 2, 1) tile(0, 1, 1) tile(8, 1, 0) 


In [27]:
class MaudeEnv():
    """Episodic environment over the rewrite system of the Maude module ``m``.

    A state is a dict with the concrete term (``'term'``), its abstraction
    ``alpha(term)`` (``'acfg'``) and the labels of the rules that can be applied
    to the term (``'actions'``).
    """

    def __init__(self, g=None, init_term=None):
        """Create the environment and move it to its first state.

        Args:
            g: zero-argument callable returning a start term; used by
                ``reset()`` whenever no explicit term is passed.
            init_term: optional explicit start term.
        """
        self.conf_gen = g
        # Collected before the first reset(): reset() filters these labels down
        # to the applicable ones, so it must not depend on a notebook global.
        self.rules = [
            label for label in (rl.getLabel() for rl in m.getRules())
            if label is not None
        ]
        self.reset(init_term)

    def reset(self, init_term=None):
        """Move to ``init_term`` (or to a freshly generated term) and return the state.

        Raises:
            ValueError: if no term is given and the environment has no generator.
        """
        if init_term is None:
            if self.conf_gen is None:
                raise ValueError("reset() needs init_term when no generator was given")
            init_term = self.conf_gen()
        self.term = init_term
        self.acfg = self.get_acfg(init_term)
        # A rule is an available action when applying it yields at least one result.
        self.actions = [rl for rl in self.rules if any(True for _ in init_term.apply(rl))]
        return self.get_state()

    def get_state(self):
        """Return the current state as a dict with keys 'term', 'acfg' and 'actions'."""
        return {
            'term' : self.term,
            'acfg' : self.acfg,
            'actions' : self.actions,
        }

    def step(self, action):
        """Apply the rule labelled ``action`` and return ``(state, reward, done)``.

        One of the terms produced by the rule is chosen at random as the next
        term. ``reward`` is 1.0 when the new term satisfies ``is_goal()`` and
        0.0 otherwise; ``done`` is True exactly when the goal was reached.

        Raises:
            ValueError: if the rule produces no term from the current one.
        """
        successors = [next_t for next_t, _, _, _ in self.term.apply(action)]
        if not successors:
            raise ValueError("invalid action")
        state = self.reset(random.choice(successors))
        reached_goal = self.is_goal()
        reward = 1.0 if reached_goal else 0.0
        done = reached_goal # TODO: +done if no rewrite possible
        return state, reward, done

    def get_acfg(self, t):
        """Return the reduced term ``alpha(<t>)``."""
        acfg = m.parseTerm(f'alpha({t.prettyPrint(0)})') # TODO change aconf to acfg
        acfg.reduce()
        return acfg

    def is_goal(self):
        """Return True when ``isGoal(<current term>)`` reduces to the term printed as ``true``."""
        verdict = m.parseTerm(f'isGoal({self.term.prettyPrint(0)})')
        verdict.reduce()
        return verdict.prettyPrint(0) == 'true'

In [71]:
# Environment without a fixed start term: every reset() without an argument
# draws a new board from st_generator.
env = MaudeEnv(st_generator)

In [28]:
class QTable():
    """Sparse Q-table stored as ``q_dict[state][action] -> value``.

    Pairs without a stored value read as ``q_init``. Values equal to 0.0 are
    never stored (see ``set_q``).
    """

    def __init__(self):
        self.q_init = 0.0
        self.q_dict = dict()

    def get_q(self, s, a):
        """Return the value stored for ``(s, a)``, or ``q_init`` when there is none."""
        if s in self.q_dict:
            return self.q_dict[s].get(a, self.q_init)
        return self.q_init

    def set_q(self, s, a, q):
        """Store ``q`` for ``(s, a)``; a value of 0.0 is ignored and leaves the table unchanged."""
        # TODO deepcopy terms
        if q == 0.0: # TODO
            return
        self.q_dict.setdefault(s, {})[a] = q

    def argmax_q(self, s, actions):
        """Return an action from ``actions`` with the highest value in state ``s``.

        Actions without a stored value count as ``q_init``. When several actions
        share the highest value one of them is picked uniformly at random, so
        the result is not biased towards the order of ``actions``.

        Returns -1 when ``s`` has no stored entry or ``actions`` is empty.
        """
        if s not in self.q_dict or len(actions) == 0:
            return -1
        row = self.q_dict[s]
        scores = { a : row.get(a, self.q_init) for a in actions }
        best = max(scores.values())
        return random.choice([a for a, q in scores.items() if q == best])

    def max_q(self, s):
        """Return the largest value stored for ``s``, or ``q_init`` when ``s`` has no entry."""
        if s in self.q_dict: # rows are created with one entry, so they are never empty
            return max(self.q_dict[s].values())
        return self.q_init

    def get_size(self):
        """Return the number of stored (state, action) entries."""
        return sum(len(row) for row in self.q_dict.values())

    def print_v(self):
        """Print a Maude functional module mapping each stored state to its ``max_q``."""
        print('fmod SCORE is')
        for t in self.q_dict:
            print(f'  eq score({t}) = {self.max_q(t)} .')
        print(f'  eq score(X) = {self.q_init} [owise] .')
        print('endfm')

    def print_q(self):
        """Print a Maude module with one ``score`` equation per stored entry."""
        print('load dp.maude')
        print('mod SCORE is')
        print('  pr DP5 .')
        print('  pr FLOAT .')
        print('  op score : AConf AConf -> Float .')
        for t1, row in self.q_dict.items():
            for t2, q in row.items():
                print(f'  eq score({t1}, {t2}) = {q} .')
        print(f'  eq score(X:AConf, Y:AConf) = {self.q_init} [owise] .') # TODO: 0 should be printed 0.0
        print('endm')

    def get_q_closure(self, states, a):
        """Return the largest ``get_q(s, a)`` over ``states``; ``q_init`` if ``states`` is empty."""
        return max((self.get_q(s, a) for s in states), default=self.q_init)

    def max_q_closure(self, states):
        """Return the largest ``max_q(s)`` over ``states``; ``q_init`` if ``states`` is empty."""
        return max((self.max_q(s) for s in states), default=self.q_init)
    
    # def argmax_q_closure

In [29]:

#print(path)
#path.reverse()
def inverse_action(action):
    """Return the move that undoes ``action``; None for anything but the four move names."""
    opposites = {
        'up': 'down',
        'down': 'up',
        'right': 'left',
        'left': 'right',
    }
    return opposites.get(action)

def generate_episode():
    """Create a shuffled board plus the move sequence that undoes the shuffle.

    Returns:
        ``(env, path)``: a ``MaudeEnv`` positioned on the shuffled term (built
        without a generator) and the recorded shuffle moves, reversed and each
        replaced by its inverse.
    """
    shuffle_moves = []
    start_term = st_generator(shuffle_moves)
    undo_moves = [inverse_action(move) for move in reversed(shuffle_moves)]
    return MaudeEnv(None, start_term), undo_moves

# Sanity check: replay the undo path of a fresh episode and print the reward of
# every step. The loop does not stop when `done` becomes True.
env, path = generate_episode()
print(path)
for action in path:
    state, reward, done = env.step(action)
    print(reward)

['left', 'right', 'left', 'down', 'down', 'left', 'right', 'left', 'right', 'left', 'up', 'down', 'up', 'right', 'right', 'left', 'left', 'up', 'right', 'right', 'down', 'up', 'left', 'right', 'down', 'down', 'up', 'down', 'up', 'up', 'down', 'up', 'left', 'right', 'left', 'down', 'down', 'up', 'down', 'right', 'left', 'left', 'up', 'down', 'up', 'down', 'up', 'down', 'up', 'right', 'down', 'right', 'up', 'down', 'up', 'left', 'up', 'down', 'right', 'down', 'up', 'left', 'left', 'up', 'right', 'left', 'right', 'down', 'left', 'right', 'right', 'down', 'up', 'up', 'left', 'left', 'right', 'down', 'up', 'down', 'up', 'right', 'down', 'up', 'left', 'left', 'down', 'up', 'down', 'right', 'up', 'down', 'left', 'down', 'right', 'up', 'left', 'up', 'right', 'left']
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0

In [30]:
def _pre_train_loop(n_training_episodes, qt, next_value):
    """Shared loop of the pre-training variants.

    Every episode replays the undo path from ``generate_episode()`` and applies
    the update ``Q(s,a) := Q(s,a) + lr * (r + gamma * next_value(s') - Q(s,a))``
    on abstract states. ``learning_rate`` and ``gamma`` are read from the
    notebook globals when the function is called.
    """
    stat = 0
    for _ in tqdm(range(n_training_episodes)):
        env, path = generate_episode()
        state = env.get_state()

        for a in path:
            s = state['acfg']
            new_state, reward, done = env.step(a)
            ns = new_state['acfg']
            stat += reward

            old_q = qt.get_q(s, a)
            qt.set_q(s, a, old_q + learning_rate * (reward + gamma * next_value(ns) - old_q))

            # The goal was reached: finish the episode
            if done:
                break

            # Our next state is the new state
            state = new_state
    print('training done!')
    return qt, stat


def pre_train(n_training_episodes, max_steps, qt):
    """Pre-train ``qt`` on undo paths, bootstrapping from ``qt.max_q`` of the next state.

    Args:
        n_training_episodes: number of generated episodes.
        max_steps: accepted for symmetry with ``train``; not used, an episode
            is as long as its undo path.
        qt: the ``QTable`` to update in place.

    Returns:
        ``(qt, stat)`` where ``stat`` is the sum of all rewards received.
    """
    return _pre_train_loop(n_training_episodes, qt, qt.max_q)


def pre_train_closure(n_training_episodes, max_steps, qt, k=5):
    """Like ``pre_train``, but bootstrap from the best value over a closure.

    The next-state value is ``qt.max_q_closure`` over
    ``get_k_closure(next_state, k)``.

    Args:
        n_training_episodes: number of generated episodes.
        max_steps: accepted for symmetry with ``train``; not used.
        qt: the ``QTable`` to update in place.
        k: depth passed to ``get_k_closure`` (default 5, the value used so far).

    Returns:
        ``(qt, stat)`` where ``stat`` is the sum of all rewards received.
    """
    return _pre_train_loop(
        n_training_episodes, qt,
        lambda ns: qt.max_q_closure(get_k_closure(ns, k)),
    )

In [31]:
# Training parameters
n_training_episodes = 1000  # Total training episodes (the training calls below pass their own counts)
learning_rate = 0.7  # Learning rate; read as a global by pre_train, pre_train_closure and train

# Evaluation parameters
n_eval_episodes = 1000  # Total number of test episodes

# Environment parameters
#env_id = "FrozenLake-v1"  # Name of the environment
max_steps = 1000  # Max steps per episode (used by train; the pre-training functions ignore it)
gamma = 0.95  # Discounting rate; read as a global by pre_train, pre_train_closure and train
eval_seed = []  # The evaluation seed of the environment

# Exploration parameters
max_epsilon = 1.0  # Exploration probability at start
min_epsilon = 0.05  # Minimum exploration probability
decay_rate = 0.0005  # Exponential decay rate for exploration prob

In [32]:
# Q-table pre-trained on 600 undo-path episodes with the plain max_q bootstrap.
preQtable = QTable()
preQtable, stat = pre_train(600, max_steps, preQtable)

  0%|          | 0/600 [00:00<?, ?it/s]

training done!


In [343]:
# Q-table pre-trained on 600 undo-path episodes with the closure-based bootstrap.
preQtableClosure = QTable()
preQtableClosure, stat = pre_train_closure(600, max_steps, preQtableClosure)

  0%|          | 0/600 [00:00<?, ?it/s]

training done!


In [344]:
# Reward total returned by whichever training call assigned `stat` last.
stat

600.0

In [33]:
# Raw contents of the pre-trained table: abstract state -> {action: value}.
preQtable.q_dict

{tile(zero, 0, 1) tile(stay, 0, 2) tile(stay, 1, 0) tile(stay, 1, 1) tile(stay, 1, 2) tile(stay, 2, 0) tile(stay, 2, 1) tile(stay, 2, 2) tile(right, 0, 0): {'left': 1.0,
  'right': 0.9025,
  'down': 0.9025},
 tile(zero, 1, 0) tile(stay, 0, 1) tile(stay, 0, 2) tile(stay, 1, 1) tile(stay, 1, 2) tile(stay, 2, 0) tile(stay, 2, 1) tile(stay, 2, 2) tile(down, 0, 0): {'up': 1.0,
  'down': 0.9025,
  'right': 0.9025},
 tile(zero, 1, 1) tile(stay, 0, 2) tile(stay, 1, 0) tile(stay, 1, 2) tile(stay, 2, 0) tile(stay, 2, 1) tile(stay, 2, 2) tile(right, 0, 0) tile(down, 0, 1): {'up': 0.95,
  'left': 0.8573749999999999,
  'down': 0.8573749999999999,
  'right': 0.8573749999999999},
 tile(zero, 0, 2) tile(stay, 1, 0) tile(stay, 1, 1) tile(stay, 1, 2) tile(stay, 2, 0) tile(stay, 2, 1) tile(stay, 2, 2) tile(right, 0, 0) tile(right, 0, 1): {'left': 0.95,
  'down': 0.8573749999999999},
 tile(zero, 2, 0) tile(stay, 0, 1) tile(stay, 0, 2) tile(stay, 1, 1) tile(stay, 1, 2) tile(stay, 2, 1) tile(stay, 2, 2) til

In [34]:
# Number of stored (state, action) entries after plain pre-training.
preQtable.get_size()

2385

In [345]:
# Number of stored (state, action) entries after closure-based pre-training.
preQtableClosure.get_size()

4314

In [317]:
# Compare, for a random abstract state, the value stored for action 'right'
# with the best value of 'right' over the state's depth-5 closure.
# NOTE(review): `Qtable` is only assigned in the "train Qtable" cell further
# down in this notebook; on a fresh top-to-bottom run this cell would fail with
# NameError unless that cell is moved above it -- confirm the intended order.
t = st_generator(n_shuffles=50)
s = get_acfg(t)
#s = m.parseTerm('tile(zero, 0, 0) tile(stay, 0, 2) tile(stay, 1, 2) tile(stay, 2, 2) tile(up, 1, 1) tile(left, 2, 1) tile(upRight, 2, 0) tile(downLeft, 0, 1) tile(downRight, 1, 0)')
#s.reduce()
print('q:', Qtable.get_q(s,'right'))
states = get_k_closure(s,5)
print(len(states))
print('q_closure:', Qtable.get_q_closure(states,'right'))

q: 0.0
81
q_closure: 0.6983369113499565


In [348]:
import heapq

class TermWrapper():
    """Box that lets a term sit in a heap entry without ever being compared.

    Heap entries are ``(priority, TermWrapper)`` tuples; when two priorities are
    equal the tuple comparison falls through to the wrappers, and ``__lt__``
    answers that no wrapper is smaller than another.
    """

    def __init__(self,t):
        self.t = t

    def __lt__(self, other):
        # Wrappers are deliberately unordered: ties between equal priorities
        # are left to the heap.
        return False

# Best-first search over concrete terms, optionally guided by an abstract Q-table
def search(t, mode=1, max_step=10000, qt=None, k=5):
    """Search from ``t`` for a goal term and count the terms that were expanded.

    Terms are taken from a priority queue (smallest priority first). A term
    that was already visited is skipped; otherwise it is counted, tested with
    ``is_goal`` and, if it is not the goal, its neighbours from ``get_nbrs``
    are pushed with a priority that depends on ``mode``:

    * ``0``: the current count, so earlier expansions are served first.
    * ``1``: ``-qt.get_q(s, a)`` for the abstraction ``s`` of the expanded term
      and the action ``a`` leading to the neighbour.
    * ``2``: ``-qt.get_q_closure(closure, a)`` where ``closure`` is
      ``get_k_closure(s, k)``.

    Args:
        t: start term.
        mode: 0, 1 or 2 (see above).
        max_step: upper bound on the number of counted terms.
        qt: Q-table; needed by modes 1 and 2.
        k: closure depth for mode 2 (default 5, the value used so far).

    Returns:
        The number of distinct terms taken from the queue, including the goal
        term when it was found. The value alone does not tell whether the goal
        was reached or the search stopped for another reason.

    Raises:
        ValueError: when a term has to be expanded with an unknown ``mode``, or
            with mode 1 or 2 and no ``qt``.
    """
    visited = set()
    i = 0
    queue = [(i,TermWrapper(t))] # (priority, concrete_state)

    while queue and i < max_step:
        t = heapq.heappop(queue)[1].t
        if t in visited:
            continue
        i += 1
        visited.add(t)
        if is_goal(t):
            break

        if mode == 0:
            p_nbrs = [(i, TermWrapper(nt)) for (nt, a) in get_nbrs(t)]
        elif mode in (1, 2):
            if qt is None:
                raise ValueError('search modes 1 and 2 need a Q-table (qt)')
            s = get_acfg(t)
            if mode == 1:
                p_nbrs = [(-qt.get_q(s, a), TermWrapper(nt)) for (nt, a) in get_nbrs(t)]
            else:
                # The closure depends only on s: compute it once per expansion,
                # not once per neighbour.
                closure = get_k_closure(s, k)
                p_nbrs = [(-qt.get_q_closure(closure, a), TermWrapper(nt)) for (nt, a) in get_nbrs(t)]
        else:
            raise ValueError(f'unknown search mode: {mode!r}')

        for item in p_nbrs:
            heapq.heappush(queue, item)
    return i

In [361]:
# evaluation of search with various modes
# For each of N freshly shuffled boards, record the count returned by search()
# in mode 1 (guided by preQtable) and in mode 2 (closure-guided by
# preQtableClosure). The count is capped at 100000 per run.
# NOTE(review): the two modes also use two differently trained Q-tables, so the
# comparison mixes the effect of the search mode with that of the training
# variant -- confirm this is intended.
N = 100
stat_1 = []
stat_2 = []
i = 0
while i < N:
    print(f'test {i}')
    t = st_generator(n_shuffles=1000)
    #search(t, 0,100000) # very slow!
    stat_1.append(search(t, 1, 100000, preQtable))
    stat_2.append(search(t, 2, 100000, preQtableClosure))
    i += 1
print('=== stat 1 ===')
print(stat_1)
print('=== stat 2 ===')
print(stat_2)

test 0
test 1
test 2
test 3
test 4
test 5
test 6
test 7
test 8
test 9
test 10
test 11
test 12
test 13
test 14
test 15
test 16
test 17
test 18
test 19
test 20
test 21
test 22
test 23
test 24
test 25
test 26
test 27
test 28
test 29
test 30
test 31
test 32
test 33
test 34
test 35
test 36
test 37
test 38
test 39
test 40
test 41
test 42
test 43
test 44
test 45
test 46
test 47
test 48
test 49
test 50
test 51
test 52
test 53
test 54
test 55
test 56
test 57
test 58
test 59
test 60
test 61
test 62
test 63
test 64
test 65
test 66
test 67
test 68
test 69
test 70
test 71
test 72
test 73
test 74
test 75
test 76
test 77
test 78
test 79
test 80
test 81
test 82
test 83
test 84
test 85
test 86
test 87
test 88
test 89
test 90
test 91
test 92
test 93
test 94
test 95
test 96
test 97
test 98
test 99
=== stat 1 ===
[259, 302, 1575, 115, 117, 751, 3261, 1417, 102, 716, 49, 956, 506, 1707, 1273, 998, 691, 196, 633, 309, 149, 25, 1681, 455, 362, 1197, 58, 136, 1265, 290, 183, 1160, 1413, 511, 118, 570, 231, 40

In [363]:
# Mean and standard deviation of the expansion counts of both search modes.
s1 = np.array(stat_1)
s2 = np.array(stat_2)
for label, counts in (('stat1', s1), ('stat2', s2)):
    print(f'{label} mean:', np.mean(counts))
    print(f'{label} std:', np.std(counts))

stat1 mean: 757.3
stat1 std: 668.5020194434719
stat2 mean: 335.27
stat2 std: 336.34443224171264


In [167]:
def greedy_policy(Qt, state):
    """Exploit: ask ``Qt`` for the best of the state's available actions.

    Returns whatever ``Qt.argmax_q`` returns for the state's abstract
    configuration and its action list.
    """
    return Qt.argmax_q(state["acfg"], state["actions"])

def eps_greedy_policy(Qtable, state, epsilon):
    """Pick an action: explore with probability about ``epsilon``, else exploit.

    A number is drawn uniformly from [0, 1]; when it exceeds ``epsilon`` the
    greedy policy decides. Otherwise a random available action is returned, or
    -1 when the state has no available action.
    """
    explore = not (random.uniform(0, 1) > epsilon)
    if not explore:
        return greedy_policy(Qtable, state)
    candidates = state["actions"]
    if len(candidates) == 0:
        return -1
    return random.choice(candidates)

        
def train(n_training_episodes, min_epsilon, max_epsilon, decay_rate, env, max_steps, qt):
    """Epsilon-greedy Q-learning on the abstract states of ``env``.

    For every episode the exploration rate is
    ``min_epsilon + (max_epsilon - min_epsilon) * exp(-decay_rate * episode)``.
    The environment is reset and at most ``max_steps`` actions are taken, each
    followed by the update
    ``Q(s,a) := Q(s,a) + lr * (r + gamma * max_a' Q(s',a') - Q(s,a))``.
    ``learning_rate`` and ``gamma`` are read from the notebook globals.

    An episode ends early when the environment reports ``done`` or when the
    policy returns the integer sentinel -1 (no action available).
    NOTE(review): ``QTable.argmax_q`` also returns -1 for a state without any
    stored value, so an exploiting step in such a state ends the episode --
    confirm this is intended.

    Args:
        n_training_episodes: number of episodes.
        min_epsilon, max_epsilon, decay_rate: exploration schedule.
        env: environment providing ``reset()`` and ``step(action)``.
        max_steps: maximum number of steps per episode.
        qt: the Q-table to update in place.

    Returns:
        ``(qt, stat)`` where ``stat`` is the sum of all rewards received.
    """
    stat = 0
    for episode in tqdm(range(n_training_episodes)):
        # Reduce epsilon (because we need less and less exploration)
        epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * episode)
        state = env.reset()

        for _ in range(max_steps):
            s = state["acfg"]
            a = eps_greedy_policy(qt, state, epsilon)

            # Actions are rule labels; an int is the "no action" sentinel.
            if isinstance(a, int):
                break

            # Take the action and observe the outcome state and the reward
            new_state, reward, done = env.step(a)
            ns = new_state['acfg']
            stat += reward

            old_q = qt.get_q(s, a)
            qt.set_q(s, a, old_q + learning_rate * (reward + gamma * qt.max_q(ns) - old_q))

            # The goal was reached: finish the episode
            if done:
                break

            # Our next state is the new state
            state = new_state
    print('training done!')
    return qt, stat

In [219]:
# train Qtable
# Two phases on the same table: 300 pre-training episodes on undo paths, then
# 5000 epsilon-greedy episodes on boards drawn by the environment's generator.
env = MaudeEnv(st_generator)

Qtable = QTable()

print("=== Pre-Training ===")
Qtable, stat = pre_train(300, max_steps, Qtable)
print('stat:', stat)
print('Q-size:', Qtable.get_size())

print("=== Training ===")
Qtable, stat = train(5000, min_epsilon, max_epsilon, decay_rate, env, max_steps, Qtable)
print('stat:', stat)
print('Q-size:', Qtable.get_size())

=== Pre-Training ===


  0%|          | 0/300 [00:00<?, ?it/s]

training done!
stat: 300.0
Q-size: 1261
=== Training ===


  0%|          | 0/5000 [00:00<?, ?it/s]

training done!
stat: 287.0
Q-size: 1432


In [173]:
#import copy
#tmp_Qtable = copy.deepcopy(Qtable)

In [218]:
# Number of stored (state, action) entries in the trained table.
Qtable.get_size()

1480

In [179]:
# NOTE(review): `tmp_Qtable` is only mentioned in commented-out code in this
# notebook, so this cell presumably fails with NameError on a fresh kernel --
# remove it or restore the assignment.
tmp_Qtable.get_size()

0

In [None]:
# NOTE(review): second definition of MaudeEnv in this notebook. If this cell is
# executed it replaces the class defined earlier. Its states expose 'nbrs' and
# 'next_acfg' instead of 'actions', while greedy_policy / eps_greedy_policy read
# state["actions"], so the training code above would fail with this version.
# Presumably an older variant -- confirm and delete or rename it.
class MaudeEnv():
    """Environment variant whose actions are the abstract configurations of the successors."""

    def __init__(self, g=None, init_term=None):
        # g: zero-argument callable producing a start term, used by reset()
        # when no explicit term is given.
        self.conf_gen = g
        self.reset(init_term)
        
    def reset(self, init_term=None):
        """Move to ``init_term`` (or a generated term) and return the state dict."""
        if init_term == None:
            t = self.conf_gen()
        else:
            t = init_term
        self.term = t
        self.acfg = self.get_acfg(t)
        # Pairs (successor term, its abstract configuration) for every result of
        # a depth-1 search from the current term.
        self.nbrs = [(t,self.get_acfg(t)) for t,_,_,_ in t.search(1, m.parseTerm('X:Conf'), depth = 1)]  
        self.next_acfg = list(set([action for (_,action) in self.nbrs])) # remove duplicates
        return self.get_state() 
    
    def get_state(self):
        """Return the current state as a dict with keys 'term', 'acfg', 'nbrs' and 'next_acfg'."""
        return {
            'term' : self.term,
            'acfg' : self.acfg,
            'nbrs' : self.nbrs,
            'next_acfg' : self.next_acfg,
        }
        
    def step(self, action):
        """Move to a random successor whose abstract configuration equals ``action``.

        Returns ``(state, reward, done)``; raises Exception("invalid action")
        when no successor matches.
        """
        pairs = [(term, acfg) for (term,acfg) in self.nbrs if acfg == action]
        if pairs == []:
            raise Exception("invalid action")
        state = self.reset(random.choice(pairs)[0])
        reward = 1 if self.is_goal() else 0
        # Finished on reaching the goal or when the new term has no successor.
        done = True if reward == 1 or self.nbrs == [] else False
        return state, reward, done
    
    def get_acfg(self, t):
        """Return the reduced term ``alpha(<t>)``."""
        acfg = m.parseTerm('alpha(' + t.prettyPrint(0) + ')') # TODO change aconf to acfg
        acfg.reduce()
        return acfg
    
    def is_goal(self):
        """Return True when ``isGoal(<current term>)`` reduces to the term printed as ``true``."""
        t = m.parseTerm(f'isGoal({self.term.prettyPrint(0)})')
        t.reduce()
        return t.prettyPrint(0) == 'true'