In [218]:
import gym
import numpy as np
np.set_printoptions(threshold=np.inf)
import matplotlib.pylab as plt
import random

Посмотрим, что за задача такая дана.

In [219]:
env = gym.make("Taxi-v2")

In [220]:
env.action_space.n

6

In [221]:
env.action_space.sample()

3

In [222]:
env.observation_space

Discrete(500)

In [154]:
env.reset()
env.render()

+---------+
|[34;1mR[0m: | : :G|
| : : : : |
| : : : : |
| | : | : |
|[35m[43mY[0m[0m| : |B: |
+---------+



In [167]:
print(env.step(5))
env.render('human')

(418, 20, True, {'prob': 1.0})
+---------+
|R: | : :G|
| : : : : |
| : : : : |
| | : | : |
|[35m[42mY[0m[0m| : |B: |
+---------+
  (Dropoff)


In [163]:
env.step(4)

(118, -10, False, {'prob': 1.0})

In [None]:
env.render()

Попробуем применить Q-learning

In [392]:
class QLearner:

    def __init__(self, env):
        self.env = env
        self.Q = -np.ones((env.observation_space.n, env.action_space.n))
        self.possible_actions = list(range(env.action_space.n))
        self.epsilon = 0.1
        self.alpha = 0.2
        self.gamma = 0.8
        
    def calcRandomAction(self):
        return env.action_space.sample()
    
    def calcBestAction(self, state):
        return np.argmax(self.Q[state])
    
    def calcMaxQ(self, state):
        return np.max(self.Q[state])    
        
    def calcNextAction(self, state):
        if np.max(self.Q[state]) == np.min(self.Q[state]):
            return self.calcRandomAction()
        #if np.max(self.Q[state]) < -0.5:
        #    return self.calcRandomAction()   
        if random.random() < self.epsilon:
            return self.calcRandomAction()
        else:
            return self.calcBestAction(state)
        
    def update(self, state, action, nextState, reward):
        #if reward > 0:
        #    print("update", state, action, nextState, reward)
        oldQ = self.Q[state, action]
        maxQ = self.calcMaxQ(state)
        updatedQ = oldQ + self.alpha * (reward + self.gamma * maxQ - oldQ)
        #if updatedQ > 0:
        #    print("Hallilujah")
        self.Q[state, action] = updatedQ
        
    def cutRepetitives(self, states, actions, rewards):
        lastState = states[-1]
        history = list(zip(states[:-1], actions, rewards))
        lastOcc = {}
        for i in range(len(history)):
            lastOcc[history[i][0]] = i
        i = 0
        nhistory = []
        while i < len(history):
            if i < lastOcc[history[i][0]]:
                i = lastOcc[history[i][0]]
            nhistory.append(history[i])
            i += 1
        states, actions, rewards = map(list, zip(*nhistory))
        states.append(lastState)
        return states, actions, rewards
        
        
        
    def play(self, maxSteps=None):
        env = self.env
        state = env.reset()
        states = [state]
        actions = []
        rewards = []
        #for i in range(maxSteps):
        while(True):
            action = self.calcNextAction(states[-1])
            actions.append(action)
            state, reward, isDone, info = env.step(action)
            rewards.append(reward)
            states.append(state)
            if isDone:
                self.Q[state] = np.ones_like(self.Q[state]) * reward
                break
                
        self.cutRepetitives(states, actions, rewards)
                
        if rewards[-1] > 0: 
            rew = rewards[-1]
            for i in reversed(range(0, len(states) - 1)):
                rew = rew * 0.9 + rewards[i] * 0.1
                self.update(states[i], actions[i], states[i + 1], rew)
        #else:
            #for i in reversed(range(0, len(states) - 1)):
            #    self.update(states[i], actions[i], states[i + 1], rewards[i] * 0.3)
        #if rewards[-1] > 0:
        #    print(rewards)
        return rewards[-1]
            
    

In [393]:
ql = QLearner(env)

In [394]:
rewsum = 0.0
rewc = 0
while True:
    #random.seed(0)
    rew = ql.play()
    rewsum += rew
    rewc += 1
    if random.random() < 0.001:
        ql.epsilon *= 0.98
        print((ql.Q > 0).sum(), rew, rewsum / rewc)

531 20 3.420503909643788
878 -1 8.99304731219264
888 -1 9.064564564564565
891 -1 9.084247258225323
1128 20 10.796714579055442
1202 20 11.329773869346734
990 20 12.62992506647329
1007 20 12.641810918774967
1065 20 12.774772727272728
1187 -1 13.120320495556154
1200 20 13.1824686940966
1106 20 13.404633339139522
1208 20 13.590864440078585
1275 20 13.731676428684501
1280 20 13.735480130705175
1284 20 13.738072669826224
1337 20 14.201183932346723
1355 20 14.276772028667303
1290 20 14.28014857281279
1262 20 14.270511583011583
1078 20 14.416917509385463
1045 20 14.397596295931436
1069 -1 14.41974223784417
1162 20 14.527901435826697
1173 20 14.542151796276695
1282 -1 14.841578011682705
1195 20 14.837607544672402
1281 -10 15.025651455446804
1366 20 15.110360804416404
1353 20 15.13565947589402
1281 20 15.20242691671263
1271 20 15.207571712192784
1252 20 15.20608269734594
1296 20 15.24742041712404
1286 20 15.249062068057658
1304 20 15.26857404540201
1311 20 15.38581465929406
1331 20 15.4065965839

1322 20 18.716429043527523
1322 20 18.719572383912592
1323 20 18.722624353605557
1323 20 18.72504032812715
1323 20 18.72592058032409
1323 20 18.72724861202568
1323 20 18.7348635235732
1323 20 18.73502831516986
1323 20 18.735365611667824
1324 20 18.738607792625487
1324 20 18.747139535383756
1324 20 18.7489226220236
1325 20 18.755476902422654
1325 20 18.759636174541342
1326 20 18.76548142229325
1326 20 18.766426028810006
1326 20 18.770145545441114
1328 20 18.77616051416381
1328 20 18.776755108100083
1328 20 18.78470422081985
1328 20 18.78885777660323
1328 20 18.79611100288664
1328 20 18.796423618283217
1330 20 18.812216799263823
1330 20 18.815297613342278
1331 20 18.819125374560297
1331 20 18.823281644428416
1331 20 18.834455404803
1331 20 18.83517437543551
1331 20 18.83555015030961
1331 20 18.83654328875907
1331 20 18.837246428042008
1331 20 18.83875265841597
1332 20 18.843162725266748
1334 20 18.855866949792002
1334 20 18.85621025002524
1334 20 18.85679691244356
1334 20 18.858821554473

1339 20 19.344528945694293
1339 20 19.345425930768503
1339 20 19.345944282272374
1339 20 19.346022907590665
1339 20 19.347015104521756
1339 20 19.347239709404363
1339 20 19.34781774733662
1339 20 19.348217808654645
1339 20 19.348967641051825
1339 20 19.349195057254374
1339 20 19.350162134957444
1339 20 19.35186744270274
1339 20 19.351868472260833
1339 20 19.354016193332047
1339 20 19.354203301355966
1339 20 19.354334110319357
1339 20 19.35481703093918
1339 20 19.355688132938127
1339 20 19.355844783245
1339 20 19.35640362735247
1339 20 19.35704256757502
1339 20 19.358565818474514
1339 20 19.361141948100954
1339 20 19.361301956880073
1339 20 19.36426357784914
1339 20 19.365254565565394
1339 20 19.36593420255171
1339 20 19.36666252735824
1339 20 19.367203748249405
1339 20 19.369076612622642
1339 20 19.370504596073204
1339 20 19.370612380895086
1339 20 19.37099176152831
1339 20 19.37203439524364
1339 20 19.372531741344947
1339 20 19.37422355961682
1339 20 19.374545872614394
1339 20 19.3751

1339 20 19.55630625487311
1339 20 19.5567983332754
1339 20 19.557213892554522
1339 20 19.557270586703414
1339 20 19.557540405486296
1339 20 19.557599894174846
1339 20 19.558719988925144
1339 20 19.558756257482212
1339 20 19.55925672675894
1339 20 19.56080855616815
1339 20 19.56081422909585
1339 20 19.56242083915589
1339 20 19.562422716289824
1339 20 19.56242600123546
1339 20 19.56245180980651
1339 20 19.562558765603374
1339 20 19.563183036708406
1339 20 19.563399452343273
1339 20 19.5634214091126
1339 20 19.5636235878448
1339 20 19.564056730374727
1339 20 19.56435276732595
1339 20 19.564514579686634
1339 20 19.564657692685238
1339 20 19.564852231170153
1339 20 19.56539360789932
1339 20 19.566462656593668
1339 20 19.566644079173766
1339 20 19.56733109301085
1339 20 19.567343480593735
1339 20 19.568886362411654
1339 20 19.569427750416573
1339 20 19.56995236975991
1339 20 19.570782514659133
1339 20 19.571640943321153
1339 20 19.571931263920902
1339 20 19.57211217799479
1339 20 19.57279181

1339 20 19.667555055095836
1339 20 19.667626820551963
1339 20 19.667684751933344
1339 20 19.667877350312207
1339 20 19.66813101185412
1339 20 19.668254595480622
1339 20 19.6692016133936
1339 20 19.669208854499406
1339 20 19.669216899800727
1339 20 19.66929813541257
1339 20 19.66955016740556
1339 20 19.669618934025813
1339 20 19.669938309855052
1339 20 19.670114703692235
1339 20 19.670130705902423
1339 20 19.670154439655033
1339 20 19.67019176673715
1339 20 19.670544356987595
1339 20 19.670618294543605
1339 20 19.670768461100867
1339 20 19.671165939846535
1339 20 19.671358231327915
1339 20 19.671961189933118
1339 20 19.672339219232338
1339 20 19.67265257198972
1339 20 19.672898469974708
1339 20 19.67300228168696
1339 20 19.673192432119837
1339 20 19.673286639596363
1339 20 19.673398309738296
1339 20 19.67346339363101
1339 20 19.673749559218276
1339 20 19.67387499000879
1339 20 19.67392007979074
1339 20 19.674105520433553
1339 20 19.67414039719867
1339 20 19.67422781656826
1339 20 19.674

KeyboardInterrupt: 

# Отчет

Просто каноничный Q-learning работал плохо, пришлось немного допилить: во-первых постепенно понижать $\varepsilon$, во-вторых отбросить все неудачные пути и в третьих вырезать петли из удачных путей (видимо обновления работали не очень коректно, когда эффективная длина пути из 160 шагов была 10 шагов).

Проверка удачности обучения:

In [395]:
n = 1000
success = 0
for i in range(1000):
    if ql.play() == 20:
        success += 1
print(success / n)

1.0


Т.е. на выборке из 1000 игр accuracy - 100%