<a href="https://colab.research.google.com/github/takagiyuusuke/Q_TETRIS_lv1/blob/main/Q_miniTETRIS_O%E3%83%9F%E3%83%8E%E7%89%88.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ★ 最初に実行すること

## 0. インポート・乱数シードの設定

In [118]:
import random

import torch
import numpy as np


# Seed Python's, NumPy's and PyTorch's random number generators with the same
# fixed value (0).
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7b96e80784b0>

## 1. miniTETRISのロジック
横6マス×縦22マスのテトリスを定義する(このノートブックでは落下するミノはOミノのみ)

In [128]:
import copy

class Tetris:
  # Piece definitions at their spawn position.
  # Each piece is four cells [row, col, colour] followed by one two-element
  # entry [row, col] that is used as the rotation centre (see turn()).
  # The colour index (1-7) also identifies the piece type.
  block1 = [[-1, 2, 1], [0, 2, 1], [-1, 3, 1], [0, 3, 1], [-0.5, 2.5]] # O piece
  block2 = [[0, 2, 2], [0, 3, 2], [0, 4, 2], [-1, 3, 2], [0, 3]] # T piece
  block3 = [[-1, 1, 3], [-1, 2, 3], [-1, 3, 3], [-1, 4, 3], [-0.5, 2.5]] # I piece
  block4 = [[0, 2, 4], [0, 3, 4], [0, 4, 4], [-1, 2, 4], [0, 3]] # J piece
  block5 = [[-1, 4, 5], [0, 2, 5], [0, 3, 5], [0, 4, 5], [0, 3]] # L piece
  block6 = [[-1, 2, 6], [-1, 3, 6], [0, 3, 6], [0, 4, 6], [0, 3]] # Z piece
  block7 = [[0, 2, 7], [0, 3, 7], [-1, 3, 7], [-1, 4, 7], [0, 3]] # S piece
  # Indexed by colour - 1.
  blocks =[block1,block2,block3,block4,block5,block6,block7]

  # Rotation tables, indexed as roundblock[colour - 2][angle] (the O piece has
  # no entry). Each entry lists [row, col] offsets from the rotation centre:
  # three offsets for T/J/L/Z/S (the fourth cell is the centre itself) and
  # four offsets for the I piece, whose centre lies between cells.
  roundblock=[[[[0,-1],[-1,0],[0,1]],[[-1,0],[0,1],[1,0]],[[0,1],[1,0],[0,-1]],[[1,0],[0,-1],[-1,0]]],
  [[[-0.5,-1.5],[-0.5,-0.5],[-0.5,0.5],[-0.5,1.5]],[[-1.5,0.5],[-0.5,0.5],[0.5,0.5],[1.5,0.5]],[[0.5,-1.5],[0.5,-0.5],[0.5,0.5],[0.5,1.5]],[[-1.5,-0.5],[-0.5,-0.5],[0.5,-0.5],[1.5,-0.5]]],
  [[[-1,-1],[0,-1],[0,1]],[[-1,1],[-1,0],[1,0]],[[0,-1],[0,1],[1,1]],[[-1,0],[1,0],[1,-1]]],
  [[[-1,1],[0,1],[0,-1]],[[1,0],[-1,0],[1,1]],[[0,-1],[0,1],[1,-1]],[[1,0],[-1,0],[-1,-1]]],
  [[[-1,-1],[-1,0],[0,1]],[[-1,1],[0,1],[1,0]],[[0,-1],[1,0],[1,1]],[[1,-1],[0,-1],[-1,0]]],
  [[[0,-1],[-1,0],[-1,1]],[[-1,0],[0,1],[1,1]],[[1,-1],[1,0],[0,1]],[[-1,-1],[0,-1],[1,0]]]]

  # [row, col] displacements that turn() tries, in order, when rotating a
  # piece (the Super Rotation System "wall kicks"). SRS is used for every
  # rotating piece except I, which uses SRSofI. Which of the four rows applies
  # is chosen in turn() from the angle before and after the rotation.
  SRS =[[[0,0],[0,1],[-1,1],[2,0],[2,1]],[[0,0],[0,-1],[-1,-1],[2,0],[2,-1]],[[0,0],[0,-1],[1,-1],[-2,0],[-2,-1]],[[0,0],[0,1],[1,1],[-2,0],[-2,1]]]
  SRSofI =[[[0,0],[0,-1],[0,2],[-2,-1],[1,2]],[[0,0],[0,-2],[0,1],[1,-2],[-2,1]],[[0,0],[0,1],[0,-2],[1,2],[-1,-2]],[[0,0],[0,2],[0,-1],[-1,2],[2,-1]]]

  '''
  盤面の情報などの情報の初期化を行う
  display:画面を表示するか否か
  '''
  def __init__(self, display: bool):
    self.display = display
    self.reset_game()

  '''
  ゲームをリセットする
  '''
  def reset_game(self):
    self.location=[]  #現在の盤面
    self.moveblock=[] #現在落下中のブロック
    self.supports=[]  #落下後の場所を示す補助ブロック
    self.hold=[]      #ホールド中のブロック
    self.angle = 0    #落下中のブロックの角度(0,1,2,3)
    self.num=1        #落下中のブロックが何番目か
    self.lines = 0    #消したラインの数
    self.score = 0    #スコア
    self.level = 0    #レベル
    self.canhold = True  #ホールドできるか
    self.turn_right = [] #右回転した場合のブロック配置
    self.turn_left = []  #左回転した場合のブロック配置
    self.before_action = "" #前回の行動
    self.before_before_action = "" #前々回の行動
    # self.nex =[random.choice(Tetris.blocks) for i in range(20)] #後におちるブロック
    self.nex =[Tetris.block1 for i in range(20)]


    self.setblock(copy.deepcopy(self.nex[0])) #最初のブロックをセットする

  '''
  ゲームの盤面を表示する
  '''
  def printboad(self, display_force = False):
      if display_force:
        pass
      elif not self.display:
        return
      boad = "\n"
      boad += ("⏹"*8+"     NEXT "+str(self.num+1) +"\n")
      for i in range(22):
          boad += ("⏹")
          for j in range(6):
              ad = [l for l in self.location + self.moveblock if l[0] ==i and l[1]==j]
              if ad:
                  if ad[0][2] ==0:
                      boad += ("💥")
                  else:
                      boad += (chr(128996+ad[0][2]))
              else:
                  ae = [m for m in self.supports if m[0] ==i and m[1]==j]
                  if ae:
                      boad += ("⬛")
                  else:
                      boad += ("⬜")
          boad += ("⏹  ")
          if i == 0 or i == 1:
              for j in range(1,5):
                  af = [l for l in self.nex[self.num%20] if l[0] ==i-1 and l[1]==j]
                  if af:
                      boad += (chr(128996+af[0][2]))
                  else:
                      boad += ("⬜")
          elif i == 3:
              boad += (" 👆👆👆 "+str(self.num+2))
          elif i == 4 or i == 5:
              for j in range(1,5):
                  af = [l for l in self.nex[(self.num+1)%20] if l[0] ==i-5 and l[1]==j]
                  if af:
                      boad += (chr(128996+af[0][2]))
                  else:
                      boad += ("⬜")
          elif i == 7:
              boad += (" 👆👆👆 "+str(self.num+3))
          elif i == 8 or i == 9:
              for j in range(1,5):
                  af = [l for l in self.nex[(self.num+2)%20] if l[0] ==i-9 and l[1]==j]
                  if af:
                      boad += (chr(128996+af[0][2]))
                  else:
                      boad += ("⬜")
          elif i == 12:
              boad += (" LEVEL =>"+str(self.level))
          elif i ==14:
              boad += (" LINES =>"+str(self.lines))
          elif i == 16:
              boad += (" SCORE =>"+str(self.score))
          elif i == 18:
              boad += (" -HOLD-")
          elif i == 19 or i == 20:
              for j in range(1,5):
                  af = [l for l in self.hold if l[0] ==i-20 and l[1]==j]
                  if af:
                      boad += (chr(128996+af[0][2]))
                  else:
                      boad += ("⬜")
          elif i == 21:
              boad += (" ------")
          boad += ("\n")
      boad += ("⏹"*8)
      print(boad)

  '''
  引数に指定したブロックを盤面に追加する
  '''
  def setblock(self, block):
      self.moveblock.extend(block)
      self.angle = 0
      self.supportdisplay()
      if self.num%20==13:
          # self.nex =[(random.choice(Tetris.blocks) if i < 10 else self.nex[i]) for i in range(20)]
          self.nex = [(Tetris.block1 if i < 10 else self.nex[i]) for i in range(20)]
      elif self.num%20==3:
          # self.nex =[(random.choice(Tetris.blocks) if i >= 10 else self.nex[i]) for i in range(20)]
          self.nex = [(Tetris.block1 if i >= 10 else self.nex[i]) for i in range(20)]

  '''
  ブロックが落下しなくなったか判定する
  '''
  def stopfall(self):
      for i in self.moveblock:
          aa = [j for j in self.location if (j[0] ==i[0]+1 and j[1]==i[1])]
          if len(aa)>0 or i[0] >=21:
              return True
      return False

  '''
  一列そろっている箇所があれば消去する
  '''
  def removeline(self):
      lines_first = self.lines
      L = []
      max_colum = 0
      for i in range(22):
          a = [False for l in range(6)]
          for j in range(6):
              ab = [k for k in self.location if k[0]==i and k[1]==j]
              if len(ab) > 0:
                  a[j] = True
          if all(a):
              self.location = [l for l in self.location if l[0] != i]
              self.location.extend([[i,k,0] for k in range(6)])
              L.append(i)
              self.lines += 1
              if self.lines % 10 == 0:
                  self.level += 1
      if L:
              self.location = [l for l in self.location if l[0] not in L]
              self.supports = []
              for j in L:
                  for m in self.location:
                      if m[0] < j:
                          m[0] += 1
      lines_final = self.lines
      self.score += (lines_final-lines_first)**2*100*(self.level+1)

  '''
  このまま落下したときに辿り着く地点に補助ブロックを表示する
  '''
  def supportdisplay(self):
      if not self.display:
        return

      z = copy.deepcopy(self.moveblock)
      while True:
          d = True
          for i in z:
              aa = [j for j in self.location if (j[0] ==i[0]+1 and j[1]==i[1])]
              if len(aa)>0 or i[0] >=21:
                  d = False
                  break
          if not d:
              self.supports = [[z[0][0],z[0][1],8],[z[1][0],z[1][1],8],[z[2][0],z[2][1],8],[z[3][0],z[3][1],8]]
              break
          else:
              for i in z:
                  i[0] +=1

  '''
  ブロックを右に移動させる
  '''
  def right(self):
      for j in self.moveblock:
          j[1] +=1
      self.supportdisplay()
      self.printboad()

  '''
  ブロックを左に移動させる
  '''
  def left(self):
      for j in self.moveblock:
          j[1] -=1
      self.supportdisplay()
      self.printboad()

  '''
  ブロックを下に動かせるか確かめたのちに下に移動させる
  '''
  def down(self):
      if self.stopfall():
            if self.gameover():
                return
            self.moveblock = [x for x in self.moveblock if len(x) != 2]
            self.location.extend(self.moveblock)
            self.location
            self.moveblock = []
            self.removeline()
            self.setblock(copy.deepcopy(self.nex[self.num%20]))
            self.canhold = True
            self.num += 1
      else:
          for j in self.moveblock:
              j[0] +=1
          self.printboad()

  '''
  ブロックを回転させられるか確認する。
  - direction:右か左か
  '''
  def turn(self, direction):
      col = self.moveblock[0][2]
      angle_before = self.angle
      if direction == "left":
          self.angle = self.angle-1 if self.angle >0 else 3
      elif direction == "right":
          self.angle = self.angle+1 if self.angle <3 else 0
      if col > 1:
          e = Tetris.roundblock[col-2][self.angle]
          if col != 3:
              SRSn = (3 if angle_before == 1 else 2)if angle_before%2==1 else(1 if self.angle == 1 else 0)
              n = 0
              while True:
                  m = [[self.moveblock[4][0]+e[0][0]+Tetris.SRS[SRSn][n][0],self.moveblock[4][1]+e[0][1]+Tetris.SRS[SRSn][n][1]],
                                  [self.moveblock[4][0]+e[1][0]+Tetris.SRS[SRSn][n][0],self.moveblock[4][1]+e[1][1]+Tetris.SRS[SRSn][n][1]],
                                  [self.moveblock[4][0]+e[2][0]+Tetris.SRS[SRSn][n][0],self.moveblock[4][1]+e[2][1]+Tetris.SRS[SRSn][n][1]],
                                  [self.moveblock[4][0]+Tetris.SRS[SRSn][n][0],self.moveblock[4][1]+Tetris.SRS[SRSn][n][1]]]
                  ac = [j for j in self.location if (j[0] ==m[0][0] and j[1] ==m[0][1])or(j[0] ==m[1][0] and j[1] ==m[1][1])or
                      (j[0] ==m[2][0] and j[1] ==m[2][1])or(j[0] ==m[3][0] and j[1] ==m[3][1])]

                  if len(ac) ==0 and max(m[0][1],m[1][1],m[2][1],m[3][1]) <= 5 and min(m[0][1],m[1][1],m[2][1],m[3][1])>=0 and max(m[0][0],m[1][0],m[2][0],m[3][0])<=21:
                      return [[m[0][0],m[0][1],col],[m[1][0],m[1][1],col],[m[2][0],m[2][1],col],[m[3][0],m[3][1],col],m[3]]
                      break
                  elif n==4:
                      break
                  else:
                      n += 1
          elif col ==3:
              if angle_before == 0:
                  if self.angle == 1:
                      SRSn = 1
                  else:
                      SRSn = 0
              elif angle_before == 1:
                  if self.angle ==0:
                      SRSn = 3
                  else:
                      SRSn = 0
              elif angle_before == 2:
                  if self.angle == 1:
                      SRSn = 2
                  else:
                      SRSn = 3
              else:
                  if self.angle ==0:
                      SRSn = 2
                  else:
                      SRSn = 1
              n = 0
              while True:
                  m = [[round(self.moveblock[4][0]+e[0][0]+Tetris.SRSofI[SRSn][n][0]),round(self.moveblock[4][1]+e[0][1]+Tetris.SRSofI[SRSn][n][1])],
                                  [round(self.moveblock[4][0]+e[1][0]+Tetris.SRSofI[SRSn][n][0]),round(self.moveblock[4][1]+e[1][1]+Tetris.SRSofI[SRSn][n][1])],
                                  [round(self.moveblock[4][0]+e[2][0]+Tetris.SRSofI[SRSn][n][0]),round(self.moveblock[4][1]+e[2][1]+Tetris.SRSofI[SRSn][n][1])],
                                  [round(self.moveblock[4][0]+e[3][0]+Tetris.SRSofI[SRSn][n][0]),round(self.moveblock[4][1]+e[3][1]+Tetris.SRSofI[SRSn][n][1])]]
                  ac = [j for j in self.location if (j[0] ==m[0][0] and j[1] ==m[0][1])or(j[0] ==m[1][0] and j[1] ==m[1][1])or
                      (j[0] ==m[2][0] and j[1] ==m[2][1])or(j[0] ==m[3][0] and j[1] ==m[3][1])]

                  if len(ac) ==0 and max(m[0][1],m[1][1],m[2][1],m[3][1])<=5 and min(m[0][1],m[1][1],m[2][1],m[3][1])>=0 and max(m[0][0],m[1][0],m[2][0],m[3][0])<=21:
                      return [[m[0][0],m[0][1],col],[m[1][0],m[1][1],col],[m[2][0],m[2][1],col],[m[3][0],m[3][1],col],[self.moveblock[4][0]+Tetris.SRSofI[SRSn][n][0],self.moveblock[4][1]+Tetris.SRSofI[SRSn][n][1]]]
                      break
                  elif n==4:
                      break
                  else:
                      n += 1
      return []

  def turnright():
      self.moveblock = self.turn_right
      self.supportdisplay()
      self.printboad()

  def turnleft():
      self.moveblock = self.turn_left
      self.supportdisplay()
      self.printboad()

  '''
  ブロックをホールドする
  '''
  def holding(self):
      if not self.hold:
          self.hold =Tetris.blocks[self.moveblock[0][2]-1][:]
          self.moveblock = []
          self.setblock(copy.deepcopy(self.nex[self.num%20]))
          self.canhold = True
          self.num += 1
      else:
          if self.canhold:
              self.hold , self.moveblock = copy.deepcopy(Tetris.blocks[self.moveblock[0][2]-1]) , copy.deepcopy(self.hold)
              self.canhold = False
              self.supportdisplay()
          else:
              return
      self.printboad()

  '''
  ゲームオーバーかどうかの判定を行う
  '''
  def gameover(self):
    try:
      min_moveblock = min(self.moveblock[i][0] for i in range(len(self.moveblock)))
    except:
      min_moveblock = 22
    try:
      min_location = min(self.location[i][0] for i in range(len(self.location)))
    except:
      min_location = 22

    return min_moveblock < 2 and min_location < 3

  '''
  盤面と落下中のブロック、ホールドブロックの形状を返す
  '''
  def get_state(self):
    # 形状のみを抽出して返す関数。
    def transform_matrix(matrix):
      result = [0 for i in range (6)]
      if len(matrix) == 0:
        return result
      maxe0 = min(matrix[i][0] for i in range(len(matrix)-1))
      for e in matrix:
          if e[0] == round(maxe0):
              result[round(e[1])] = 1
      return result
    location = tuple(transform_matrix(self.location))
    moveblock = self.moveblock[0][2], self.angle, int(self.moveblock[4][1])
    return location, moveblock

  '''
  すべての可能な行動を返す
  - 状態のループを防止するために以下の行動を追加で禁止
  ・右移動→左移動→右移動
  ・左移動→右移動→左移動
  ・右回転→左回転→右回転
  ・左回転→右回転→左回転
  '''
  def get_possible_actions(self):
    possible_actions = ["down"]
    right = True
    for i in self.moveblock:
        aa = [j for j in self.location if (j[1] ==i[1]+1 and j[0]==i[0])]
        if len(aa) > 0 or i[1] >= 5:
            right = False
            break
    if (right):
      if not (self.before_action == "left"  and self.before_before_action == "right"):
        possible_actions.append("right")
    left = True
    for i in self.moveblock:
        aa = [j for j in self.location if (j[1] ==i[1]-1 and j[0]==i[0])]
        if len(aa) > 0 or i[1] <=0:
            left = False
            break
    if (left):
      if not (self.before_action == "right" and self.before_before_action == "left"):
        possible_actions.append("left")
    self.turn_left = self.turn("left")
    self.turn_right = self.turn("right")
    if (self.turn_left):
      if not (self.before_action == "turn-right" and self.before_before_action == "turn-left"):
        possible_actions.append("turn-left")
    if (self.turn_right):
      if not (self.before_action == "turn-left" and self.before_before_action == "turn-right"):
        possible_actions.append("turn-right")
    if self.canhold:
      possible_actions.append("hold")
    return possible_actions

  '''
  スコアを取得する
  '''
  def get_score(self):
    return self.score

  '''
  盤面の変更を行う
  '''
  def modify_board(self, action):
    self.score -= 1
    self.before_before_action = self.before_action
    self.before_action = action
    if action == "right":
        self.right()
    elif action == "left":
        self.left()
    elif action == "down":
        self.down()
    elif action == "turn-left":
        self.turnleft()
        # self.score -= 10
    elif action == "turn-right":
        self.turnright()
        # self.score -= 10
    elif action == "hold":
        self.holding()


## 2. Agentの定義

In [120]:
class Agent:
    """Base class for anything that chooses Tetris actions.

    ``frozen`` distinguishes evaluation (True) from training (False).
    Subclasses override ``action`` and, if they learn, ``update``.
    """

    def __init__(self):
        self.frozen = False

    def train(self):
        """Switch to training mode."""
        self.frozen = False

    def eval(self):
        """Switch to evaluation mode."""
        self.frozen = True

    def _observe(self, tetris: "Tetris") -> tuple[tuple[int, ...], tuple[int, int, int]]:
        """Return the game's state as produced by ``tetris.get_state()``.

        The result is a pair of plain tuples (board flags, piece
        description), not tensors as the previous annotation claimed.
        """
        return tetris.get_state()

    def action(self, tetris: "Tetris"):
        """Choose an action for the current position; the base class returns None."""
        pass

    def update(self, tetris: "Tetris", state, action, reward1, new_state):
        """Learn from one transition; the base class does nothing."""
        pass

### 2-1. Qエージェントの定義

In [138]:
class QAgent(Agent):
    """Tabular Q-learning agent with epsilon-greedy exploration.

    The Q-table maps ``(board, piece, action)`` to a value; entries that
    have not been visited count as 0.
    """

    def __init__(self, lr: float, eps = 0.3):
        super().__init__()
        self.q_table = {}
        self.discount_factor = 0.85
        self.learning_rate = lr
        self.epsilon = eps

    def action(self, tetris: Tetris) :
        """Pick an action: random with probability epsilon while training, else greedy."""
        board, block = self._observe(tetris)
        candidates = tetris.get_possible_actions()
        # No random number is drawn while the agent is frozen.
        if not self.frozen and random.random() < self.epsilon:
            return random.choice(candidates)
        greedy_move, _ = self._get_the_best(board, block, candidates)
        return greedy_move

    def _get_the_best(self, board, block, possible_moves):
        """Return ``(move, q_value)`` with the highest Q-value; ties keep the earliest move."""
        winner = (None, -float('inf'))
        for candidate in possible_moves:
            value = self.q_table.get((board, block, candidate), 0)
            if value > winner[1]:
                winner = (candidate, value)
        return winner

    def update(self, tetris, state, action, reward, next_state):
        """Apply one Q-learning update for ``state --action--> next_state``.

        ``state`` and ``next_state`` are ``(board, piece)`` pairs. ``tetris``
        must already be in the successor position, because the actions
        considered for ``next_state`` are read from it. Does nothing while
        the agent is frozen.
        """
        if self.frozen:
            return None

        key = (state[0], state[1], action)
        old_value = self.q_table.get(key, 0)
        possible_moves = tetris.get_possible_actions()

        assert next_state is not None
        successor_values = [
            self.q_table.get((next_state[0], next_state[1], follow_up), 0)
            for follow_up in possible_moves
        ]
        next_max = max(successor_values)

        self.q_table[key] = old_value + self.learning_rate * (reward + self.discount_factor * next_max - old_value)

### 2-2. 人間エージェントの定義

In [122]:
class HumanAgent(Agent):
    """Agent that asks a person at the keyboard for every move."""

    def action(self, tetris: Tetris):
        """Prompt until the user enters one of the currently possible actions.

        Rejecting other input matters: the game applies moves without
        re-checking them, so an impossible move would corrupt the position.
        """
        valid_moves = tetris.get_possible_actions()
        while True:
            user_input = input("Enter your move: ")
            if user_input in valid_moves:
                return user_input
            print(f"Invalid move {user_input!r}; choose one of: {', '.join(valid_moves)}")

## 3. Env.

In [123]:
import tqdm
import plotly.graph_objects as go

class Env:
    """Runs episodes of a Tetris game with a given agent."""

    def __init__(self, agent: Agent, tetris: Tetris) -> None:
        self.agent = agent
        self.tetris = tetris

    def _get_reward(self) -> int:
        """Return the game's current score."""
        return self.tetris.get_score()

    def train(self, episodes: int, visualize=False) :
        """Play ``episodes`` training episodes and return the score after each one."""
        record = []
        for _ in tqdm.tqdm(range(episodes)):
            self.execute(train=True, visualize=visualize)
            record.append(self._get_reward())
        print("Report:")
        return record

    def execute(self, train=False, visualize=True):
        """Play one episode from a fresh game until game over; return the final score.

        Each step's reward is the change in score caused by the action. With
        ``train`` set, the agent's ``update`` is called after every step.
        With probability 1% the final board is printed, whether or not the
        game has display enabled. ``visualize`` is accepted but not used.
        """
        game = self.tetris
        game.reset_game()

        while not game.gameover():
            state = game.get_state()
            action = self.agent.action(game)
            before_point = game.get_score()
            game.modify_board(action)
            reward = game.get_score() - before_point
            if not train:
                continue
            next_state = game.get_state()
            if action is not None:
                self.agent.update(game, state, action, reward, next_state)

        # Episode finished.
        score = game.get_score()
        if random.random() < 0.01:
            game.printboad(display_force = True)
        return score

# 4. Q学習
学習中、現状確認のために1%の確率でゲームオーバー時の盤面が出力される

In [141]:
# A fairly large epsilon such as 0.3 has empirically worked better here.
q_agent = QAgent(lr=0.1, eps=0.3)
tetris = Tetris(False)
env = Env(q_agent, tetris)
# One score per episode, plotted as a learning curve.
record = env.train(1000)
fig = go.Figure(data=go.Scatter(y=record))
fig.show()

 28%|██▊       | 276/1000 [00:59<01:58,  6.12it/s]


⏹⏹⏹⏹⏹⏹⏹⏹     NEXT 41
⏹⬜⬜🟥🟥⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜⬜⬜⬜⬜⬜⏹  ⬜🟥🟥⬜
⏹🟥🟥⬜⬜⬜⬜⏹  
⏹🟥🟥⬜⬜⬜⬜⏹   👆👆👆 42
⏹⬜🟥🟥⬜⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜🟥🟥⬜⬜⬜⏹  ⬜🟥🟥⬜
⏹🟥🟥⬜⬜⬜⬜⏹  
⏹🟥🟥⬜⬜⬜⬜⏹   👆👆👆 43
⏹🟥🟥⬜⬜🟥🟥⏹  ⬜🟥🟥⬜
⏹🟥🟥⬜⬜🟥🟥⏹  ⬜🟥🟥⬜
⏹⬜🟥🟥⬜🟥🟥⏹  
⏹⬜🟥🟥⬜🟥🟥⏹  
⏹⬜⬜🟥🟥🟥🟥⏹   LEVEL =>1
⏹⬜⬜🟥🟥🟥🟥⏹  
⏹⬜🟥🟥🟥🟥⬜⏹   LINES =>14
⏹⬜🟥🟥🟥🟥⬜⏹  
⏹🟥🟥⬜🟥🟥⬜⏹   SCORE =>2533
⏹🟥🟥⬜🟥🟥⬜⏹  
⏹⬜🟥🟥⬜🟥🟥⏹   -HOLD-
⏹⬜🟥🟥⬜🟥🟥⏹  ⬜🟥🟥⬜
⏹⬜⬜🟥🟥🟥🟥⏹  ⬜🟥🟥⬜
⏹⬜⬜🟥🟥🟥🟥⏹   ------
⏹⏹⏹⏹⏹⏹⏹⏹


 84%|████████▍ | 838/1000 [02:58<00:33,  4.79it/s]


⏹⏹⏹⏹⏹⏹⏹⏹     NEXT 35
⏹⬜⬜🟥🟥⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜⬜⬜⬜⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜🟥🟥⬜⬜⬜⏹  
⏹⬜🟥🟥⬜⬜⬜⏹   👆👆👆 36
⏹⬜⬜🟥🟥🟥🟥⏹  ⬜🟥🟥⬜
⏹⬜⬜🟥🟥🟥🟥⏹  ⬜🟥🟥⬜
⏹🟥🟥⬜🟥🟥⬜⏹  
⏹🟥🟥⬜🟥🟥⬜⏹   👆👆👆 37
⏹⬜🟥🟥⬜🟥🟥⏹  ⬜🟥🟥⬜
⏹⬜🟥🟥⬜🟥🟥⏹  ⬜🟥🟥⬜
⏹⬜⬜🟥🟥🟥🟥⏹  
⏹⬜⬜🟥🟥🟥🟥⏹  
⏹⬜⬜⬜🟥🟥⬜⏹   LEVEL =>1
⏹⬜⬜⬜🟥🟥⬜⏹  
⏹⬜⬜⬜⬜🟥🟥⏹   LINES =>10
⏹⬜⬜⬜⬜🟥🟥⏹  
⏹⬜⬜🟥🟥🟥🟥⏹   SCORE =>1628
⏹⬜⬜🟥🟥🟥🟥⏹  
⏹⬜🟥🟥⬜🟥🟥⏹   -HOLD-
⏹⬜🟥🟥⬜🟥🟥⏹  ⬜🟥🟥⬜
⏹⬜🟥🟥🟥🟥⬜⏹  ⬜🟥🟥⬜
⏹⬜🟥🟥🟥🟥⬜⏹   ------
⏹⏹⏹⏹⏹⏹⏹⏹


100%|██████████| 1000/1000 [03:35<00:00,  4.65it/s]

Report:





# 5. 学習が完了したQエージェントにプレイさせてみる
モデルが実際にどのようなプレイをするのか確かめる。

In [142]:
# Let the trained agent play one displayed game greedily (no exploration,
# no learning).
q_agent.eval()
env = Env(q_agent, Tetris(True))
try:
    env.execute()
finally:
    # Return to training mode even if the game is interrupted; otherwise the
    # agent stays frozen and later training would silently learn nothing.
    q_agent.train()

[1;30;43mストリーミング出力は最後の 5000 行に切り捨てられました。[0m
⏹⏹⏹⏹⏹⏹⏹⏹     NEXT 252
⏹⬜⬜🟥🟥⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜⬜🟥🟥⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜⬜⬜⬜⬜⬜⏹  
⏹⬜⬜⬜⬜⬜⬜⏹   👆👆👆 253
⏹⬜⬜⬜⬜⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜⬜⬜⬜⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜⬜⬜⬜⬜⬜⏹  
⏹⬜⬜⬜⬜⬜⬜⏹   👆👆👆 254
⏹⬜⬜⬜⬜⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜⬜⬜⬜⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜⬜⬜⬜⬜⬜⏹  
⏹⬜⬜⬜⬜⬜⬜⏹  
⏹⬜⬜⬜⬜⬜⬜⏹   LEVEL =>16
⏹⬜⬜⬜⬜⬜⬜⏹  
⏹⬜⬜⬜⬜⬜⬜⏹   LINES =>166
⏹⬜⬜⬜⬜⬜⬜⏹  
⏹⬜⬜⬜⬜⬜⬜⏹   SCORE =>289250
⏹⬜⬜⬜⬜⬜⬜⏹  
⏹⬜⬜⬜⬜⬜⬜⏹   -HOLD-
⏹⬜⬜⬜⬜⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜⬜⬛⬛⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜⬜⬛⬛⬜⬜⏹   ------
⏹⏹⏹⏹⏹⏹⏹⏹

⏹⏹⏹⏹⏹⏹⏹⏹     NEXT 252
⏹⬜🟥🟥⬜⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜🟥🟥⬜⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜⬜⬜⬜⬜⬜⏹  
⏹⬜⬜⬜⬜⬜⬜⏹   👆👆👆 253
⏹⬜⬜⬜⬜⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜⬜⬜⬜⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜⬜⬜⬜⬜⬜⏹  
⏹⬜⬜⬜⬜⬜⬜⏹   👆👆👆 254
⏹⬜⬜⬜⬜⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜⬜⬜⬜⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜⬜⬜⬜⬜⬜⏹  
⏹⬜⬜⬜⬜⬜⬜⏹  
⏹⬜⬜⬜⬜⬜⬜⏹   LEVEL =>16
⏹⬜⬜⬜⬜⬜⬜⏹  
⏹⬜⬜⬜⬜⬜⬜⏹   LINES =>166
⏹⬜⬜⬜⬜⬜⬜⏹  
⏹⬜⬜⬜⬜⬜⬜⏹   SCORE =>289249
⏹⬜⬜⬜⬜⬜⬜⏹  
⏹⬜⬜⬜⬜⬜⬜⏹   -HOLD-
⏹⬜⬜⬜⬜⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜⬛⬛⬜⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜⬛⬛⬜⬜⬜⏹   ------
⏹⏹⏹⏹⏹⏹⏹⏹

⏹⏹⏹⏹⏹⏹⏹⏹     NEXT 252
⏹⬜⬜🟥🟥⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜⬜🟥🟥⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜⬜⬜⬜⬜⬜⏹  
⏹⬜⬜⬜⬜⬜⬜⏹   👆👆👆 253
⏹⬜⬜⬜⬜⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜⬜⬜⬜⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜⬜⬜⬜⬜⬜⏹  
⏹⬜⬜⬜⬜⬜⬜⏹   👆👆👆 254
⏹⬜⬜⬜⬜⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜⬜⬜⬜⬜⬜⏹  ⬜🟥🟥⬜
⏹⬜⬜⬜⬜⬜⬜⏹  
⏹⬜⬜⬜⬜⬜⬜⏹  
⏹⬜⬜⬜⬜⬜⬜⏹   L

KeyboardInterrupt: 

# 6. 追加で学習させたい場合
モデルをリセットせずにさらに学習させたい場合はここを実行  
epsilon, lr を適宜変更するとよい

In [134]:
# @title 追加学習の内容 { display-mode: "both" }
epsilon = 0.1 # @param {type:"number"}
lr = 0.1 # @param {type:"number"}
# Continue training the existing agent: its Q-table is kept, only the
# exploration rate and learning rate are replaced by the values above.
q_agent.train()
tetris = Tetris(False)
q_agent.epsilon = epsilon
q_agent.learning_rate = lr
env = Env(q_agent, tetris)
# One score per episode, plotted as a learning curve.
record = env.train(1000)
fig = go.Figure(data=go.Scatter(y=record))
fig.show()

100%|██████████| 1000/1000 [08:07<00:00,  2.05it/s]

Report:





# 付録

## 手動で遊んでみる
※ 非常にユーザビリティが悪いため、やらない方がよい(笑)


In [None]:
# Play one displayed game with moves typed at the keyboard.
env = Env(HumanAgent(), Tetris(True))
env.execute()

## Qテーブルを見てみる
モデルが学習したQ値を見ることができる。ただし、Qテーブルが大きい場合は表示に非常に時間がかかる可能性がある


In [None]:
# Dump the whole Q-table: {(board, piece, action): value}.
print(q_agent.q_table)

## うまく学習できた場合のデータを用いてプレイ
学習にはランダム性があるので失敗することがある  
私が何回か試した中でうまくいった際のQテーブルを用いたプレイはこちらから

In [None]:
# @title うまくいったQテーブルを読み込む { display-mode: "form" }
good_q_table = {((0, 0, 0, 0, 0, 0), (1, 0, 2), 'down'): -1.9766366422002568, ((0, 0, 0, 0, 0, 0), (1, 0, 2), 'right'): -2.374465293128104, ((0, 0, 0, 0, 0, 0), (1, 0, 3), 'down'): -2.3585733792012817, ((0, 0, 0, 0, 0, 0), (1, 0, 3), 'right'): -2.373453134115031, ((0, 0, 0, 0, 0, 0), (1, 0, 4), 'down'): -2.3287459460552853, ((0, 0, 0, 0, 0, 0), (1, 0, 4), 'left'): -2.398511740146238, ((0, 0, 0, 0, 0, 0), (1, 0, 3), 'left'): -2.3938004859282445, ((0, 0, 0, 0, 0, 0), (1, 0, 2), 'left'): -2.374918922081176, ((0, 0, 0, 0, 0, 0), (1, 0, 1), 'hold'): -1.8219843930197728, ((0, 0, 0, 0, 0, 0), (1, 0, 2), 'hold'): -1.0620099055621157, ((0, 0, 0, 0, 0, 0), (1, 0, 1), 'down'): -2.8380311410955184, ((0, 0, 0, 0, 0, 0), (1, 0, 1), 'right'): -2.360054443240093, ((0, 0, 0, 0, 0, 0), (1, 0, 1), 'left'): -3.421025264696645, ((0, 0, 0, 0, 0, 0), (1, 0, 0), 'right'): -2.6263427635728487, ((0, 0, 0, 0, 0, 0), (1, 0, 0), 'down'): -3.67759875948391, ((0, 1, 1, 0, 0, 0), (1, 0, 2), 'down'): -2.7066482114197954, ((0, 1, 1, 0, 0, 0), (1, 0, 2), 'right'): -0.5757195479384891, ((0, 1, 1, 0, 0, 0), (1, 0, 3), 'down'): 0.5338972029796245, ((0, 1, 1, 0, 0, 0), (1, 0, 3), 'right'): -2.839267535019633, ((0, 1, 1, 0, 0, 0), (1, 0, 4), 'down'): -2.8837196397807614, ((0, 1, 1, 0, 0, 0), (1, 0, 4), 'left'): -2.253533883535133, ((0, 1, 1, 0, 0, 0), (1, 0, 3), 'left'): -2.339635847889714, ((0, 1, 1, 0, 0, 0), (1, 0, 2), 'left'): -2.8994772568959455, ((0, 1, 1, 0, 0, 0), (1, 0, 1), 'down'): -3.0638391698336114, ((0, 1, 1, 0, 0, 0), (1, 0, 1), 'right'): -2.5140408625527826, ((0, 1, 1, 0, 0, 0), (1, 0, 2), 'hold'): -2.721312651229387, ((0, 1, 1, 0, 0, 0), (1, 0, 1), 'left'): -3.397041823413185, ((0, 1, 1, 0, 0, 0), (1, 0, 0), 'down'): -3.767398126014126, ((0, 1, 1, 0, 0, 0), (1, 0, 0), 'right'): -3.142307732365744, ((0, 1, 1, 1, 1, 0), (1, 0, 2), 'down'): 7.562897547488827, ((0, 1, 1, 1, 1, 0), (1, 0, 2), 'right'): 8.38084715438736, ((0, 1, 1, 1, 1, 0), (1, 0, 3), 'down'): 9.754292791813292, ((0, 1, 1, 1, 
1, 0), (1, 0, 3), 'right'): 9.716754509225416, ((0, 1, 1, 1, 1, 0), (1, 0, 4), 'down'): 13.258068426603035, ((0, 1, 1, 1, 1, 0), (1, 0, 4), 'left'): 9.510282166831562, ((0, 1, 1, 1, 1, 0), (1, 0, 3), 'left'): 8.71411298100424, ((0, 1, 1, 1, 1, 0), (1, 0, 2), 'left'): 3.6453937911238783, ((0, 1, 1, 1, 1, 0), (1, 0, 1), 'down'): -0.07892792050348496, ((0, 1, 1, 1, 1, 0), (1, 0, 1), 'right'): 10.300013954695421, ((0, 1, 1, 1, 1, 0), (1, 0, 2), 'hold'): 4.667590938807294, ((0, 1, 1, 1, 1, 0), (1, 0, 1), 'left'): -1.6091129850043295, ((0, 1, 1, 1, 1, 0), (1, 0, 0), 'down'): -1.704789793364415, ((0, 1, 1, 1, 1, 0), (1, 0, 0), 'right'): 0.41657922302820816, ((0, 0, 0, 0, 1, 1), (1, 0, 2), 'down'): 59.974341803143695, ((0, 0, 0, 0, 1, 1), (1, 0, 2), 'left'): 63.234874019866375, ((0, 0, 0, 0, 1, 1), (1, 0, 1), 'down'): 63.014641246109974, ((0, 0, 0, 0, 1, 1), (1, 0, 1), 'right'): 62.61350683307149, ((0, 0, 0, 0, 1, 1), (1, 0, 2), 'right'): 60.64439915175068, ((0, 0, 0, 0, 1, 1), (1, 0, 3), 'hold'): 152.59371591333522, ((0, 0, 0, 0, 1, 1), (1, 0, 3), 'down'): 18.40740664546503, ((0, 0, 0, 0, 1, 1), (1, 0, 3), 'left'): 58.98667514008283, ((0, 0, 0, 0, 1, 1), (1, 0, 1), 'left'): 74.9946055187938, ((0, 0, 0, 0, 1, 1), (1, 0, 0), 'down'): 173.2054066682874, ((0, 0, 0, 0, 1, 1), (1, 0, 0), 'right'): 59.607279481081555, ((0, 0, 0, 0, 1, 1), (1, 0, 3), 'right'): 16.42987718527775, ((0, 0, 0, 0, 1, 1), (1, 0, 4), 'down'): 8.40947629676452, ((0, 0, 0, 0, 1, 1), (1, 0, 4), 'left'): 13.495906962575924, ((0, 0, 1, 1, 1, 1), (1, 0, 2), 'down'): 11.581605409249855, ((0, 0, 1, 1, 1, 1), (1, 0, 2), 'right'): 10.76996686703045, ((0, 0, 1, 1, 1, 1), (1, 0, 3), 'down'): 11.352583653109926, ((0, 0, 1, 1, 1, 1), (1, 0, 3), 'right'): 11.191120693002198, ((0, 0, 1, 1, 1, 1), (1, 0, 4), 'down'): 8.41864693974621, ((0, 0, 1, 1, 1, 1), (1, 0, 4), 'left'): 9.887249845112322, ((0, 0, 1, 1, 1, 1), (1, 0, 3), 'left'): 12.884604248521782, ((0, 0, 1, 1, 1, 1), (1, 0, 2), 'left'): 29.41085442461389, ((0, 0, 
1, 1, 1, 1), (1, 0, 1), 'down'): 20.82000723045544, ((0, 0, 1, 1, 1, 1), (1, 0, 1), 'right'): 17.129918067836137, ((0, 0, 1, 1, 1, 1), (1, 0, 2), 'hold'): 15.293428843052617, ((0, 0, 1, 1, 1, 1), (1, 0, 1), 'left'): 27.036824027677334, ((0, 0, 1, 1, 1, 1), (1, 0, 0), 'down'): 77.06857774679773, ((0, 0, 1, 1, 1, 1), (1, 0, 0), 'right'): 25.044401323927165, ((0, 0, 1, 1, 0, 0), (1, 0, 2), 'down'): -1.2846099545895693, ((0, 0, 1, 1, 0, 0), (1, 0, 2), 'right'): 3.0705655459046985, ((0, 0, 1, 1, 0, 0), (1, 0, 3), 'down'): 0.21564859656918683, ((0, 0, 1, 1, 0, 0), (1, 0, 3), 'right'): 1.1051353489333684, ((0, 0, 1, 1, 0, 0), (1, 0, 4), 'down'): 1.0515095273831068, ((0, 0, 1, 1, 0, 0), (1, 0, 4), 'left'): 0.9960197845953046, ((0, 0, 1, 1, 0, 0), (1, 0, 3), 'left'): -0.008544142493457807, ((0, 0, 1, 1, 0, 0), (1, 0, 2), 'left'): -1.27018248427728, ((0, 0, 1, 1, 0, 0), (1, 0, 1), 'down'): -1.2457648896721516, ((0, 0, 1, 1, 0, 0), (1, 0, 1), 'hold'): 0.9601878494279074, ((0, 0, 1, 1, 0, 0), (1, 0, 1), 'right'): -1.29398623458885, ((0, 0, 1, 1, 0, 0), (1, 0, 1), 'left'): -1.2455367995664675, ((0, 0, 1, 1, 0, 0), (1, 0, 0), 'down'): -1.3312431731026981, ((0, 0, 1, 1, 0, 0), (1, 0, 0), 'right'): -1.2342460315618127, ((0, 0, 0, 1, 1, 0), (1, 0, 2), 'down'): -1.7340083911130368, ((0, 0, 0, 1, 1, 0), (1, 0, 2), 'right'): -1.227771675379653, ((0, 0, 0, 1, 1, 0), (1, 0, 3), 'down'): -1.3064220236111812, ((0, 0, 0, 1, 1, 0), (1, 0, 3), 'right'): 3.1555510295639926, ((0, 0, 0, 1, 1, 0), (1, 0, 4), 'down'): 15.570258636770836, ((0, 0, 0, 1, 1, 0), (1, 0, 4), 'left'): -1.1345693452330827, ((0, 0, 0, 1, 1, 0), (1, 0, 3), 'left'): -1.6723599263361326, ((0, 0, 0, 1, 1, 0), (1, 0, 2), 'left'): -1.7647650509405597, ((0, 0, 0, 1, 1, 0), (1, 0, 1), 'down'): -2.355593974630897, ((0, 0, 0, 1, 1, 0), (1, 0, 1), 'right'): -1.6258463250491073, ((0, 0, 0, 1, 1, 0), (1, 0, 2), 'hold'): -1.6419908571486455, ((0, 0, 0, 1, 1, 0), (1, 0, 1), 'left'): -2.9356233175069795, ((0, 0, 0, 1, 1, 0), (1, 0, 0), 
'right'): -2.3722453569579676, ((0, 0, 0, 1, 1, 0), (1, 0, 0), 'down'): -3.0764547616794573, ((0, 0, 0, 0, 1, 1), (1, 0, 2), 'hold'): 52.025824449747546, ((0, 0, 1, 1, 0, 0), (1, 0, 2), 'hold'): -1.1565825272223775, ((0, 0, 1, 1, 0, 0), (1, 0, 0), 'hold'): 2.3544721496300722, ((1, 1, 1, 1, 0, 0), (1, 0, 2), 'hold'): 43.58451199859094, ((1, 1, 1, 1, 0, 0), (1, 0, 2), 'down'): 41.73098735115, ((1, 1, 1, 1, 0, 0), (1, 0, 2), 'right'): 44.20376358046829, ((1, 1, 1, 1, 0, 0), (1, 0, 3), 'down'): 45.62727306063076, ((1, 1, 1, 1, 0, 0), (1, 0, 3), 'right'): 71.31245361173191, ((1, 1, 1, 1, 0, 0), (1, 0, 4), 'left'): 45.49341290234275, ((1, 1, 1, 1, 0, 0), (1, 0, 4), 'down'): 190.94707473400516, ((1, 1, 1, 1, 0, 0), (1, 0, 3), 'left'): 46.446968222249545, ((1, 1, 1, 1, 0, 0), (1, 0, 2), 'left'): 39.08664046188055, ((1, 1, 1, 1, 0, 0), (1, 0, 1), 'down'): 25.438466155896823, ((1, 1, 1, 1, 0, 0), (1, 0, 1), 'right'): 41.1993901165634, ((1, 1, 1, 1, 0, 0), (1, 0, 1), 'left'): 21.862310262798328, ((1, 1, 1, 1, 0, 0), (1, 0, 0), 'down'): 11.99314177435555, ((1, 1, 1, 1, 0, 0), (1, 0, 0), 'right'): 40.399949795628494, ((0, 1, 1, 0, 0, 0), (1, 0, 3), 'hold'): -2.416046904981502, ((1, 1, 0, 0, 0, 0), (1, 0, 2), 'left'): 64.88646021749591, ((1, 1, 0, 0, 0, 0), (1, 0, 1), 'down'): 60.54827290129268, ((1, 1, 0, 0, 0, 0), (1, 0, 1), 'right'): 65.80236846913593, ((1, 1, 0, 0, 0, 0), (1, 0, 2), 'down'): 66.1374759660832, ((1, 1, 0, 0, 0, 0), (1, 0, 2), 'right'): 63.91169196124565, ((1, 1, 0, 0, 0, 0), (1, 0, 3), 'down'): 65.22242548062808, ((1, 1, 0, 0, 0, 0), (1, 0, 3), 'hold'): 135.18533618108444, ((1, 1, 0, 0, 0, 0), (1, 0, 3), 'right'): 65.22911773258774, ((1, 1, 0, 0, 0, 0), (1, 0, 4), 'down'): 146.87674832902047, ((1, 1, 0, 0, 0, 0), (1, 0, 4), 'left'): 66.10603239238667, ((1, 1, 0, 0, 0, 0), (1, 0, 3), 'left'): 66.55802323781981, ((1, 1, 0, 0, 0, 0), (1, 0, 1), 'left'): 17.311300656317318, ((1, 1, 0, 0, 0, 0), (1, 0, 0), 'down'): -0.5208773447705222, ((1, 1, 0, 0, 0, 0), (1, 0, 
0), 'hold'): 253.06769491075505, ((1, 1, 0, 0, 0, 0), (1, 0, 0), 'right'): 27.971189443221707, ((1, 1, 0, 0, 0, 0), (1, 0, 2), 'hold'): 89.45043917850496, ((1, 1, 1, 1, 0, 0), (1, 0, 3), 'hold'): 52.095136238764034, ((0, 0, 0, 0, 1, 1), (1, 0, 1), 'hold'): 105.80691881340726, ((0, 0, 1, 1, 1, 1), (1, 0, 3), 'hold'): 79.69970565306222, ((0, 0, 0, 1, 1, 0), (1, 0, 1), 'hold'): -1.2064017083420986, ((1, 1, 0, 1, 1, 0), (1, 0, 2), 'down'): 26.29202685005428, ((1, 1, 0, 1, 1, 0), (1, 0, 2), 'right'): 26.296214143926846, ((1, 1, 0, 1, 1, 0), (1, 0, 3), 'down'): 26.407766740365386, ((1, 1, 0, 1, 1, 0), (1, 0, 3), 'right'): 27.453972633572953, ((1, 1, 0, 1, 1, 0), (1, 0, 4), 'down'): 53.575496500905544, ((1, 1, 0, 1, 1, 0), (1, 0, 4), 'left'): 27.68218203609524, ((1, 1, 0, 1, 1, 0), (1, 0, 3), 'left'): 27.33142257186132, ((1, 1, 0, 1, 1, 0), (1, 0, 2), 'left'): 13.630213646938019, ((1, 1, 0, 1, 1, 0), (1, 0, 1), 'down'): 2.3180795217953105, ((1, 1, 0, 1, 1, 0), (1, 0, 1), 'right'): 23.778178215587044, ((1, 1, 0, 1, 1, 0), (1, 0, 2), 'hold'): 29.7557602983952, ((1, 1, 0, 1, 1, 0), (1, 0, 1), 'left'): 1.7443066763860728, ((1, 1, 0, 1, 1, 0), (1, 0, 0), 'down'): -0.9116703282970555, ((1, 1, 0, 1, 1, 0), (1, 0, 0), 'right'): 2.828908581650818, ((0, 1, 1, 0, 0, 0), (1, 0, 1), 'hold'): -2.0555333696490785, ((0, 1, 1, 0, 1, 1), (1, 0, 2), 'down'): 7.8199635188091, ((0, 1, 1, 0, 1, 1), (1, 0, 2), 'right'): 8.446945076612634, ((0, 1, 1, 0, 1, 1), (1, 0, 3), 'down'): 1.5006564495441161, ((0, 1, 1, 0, 1, 1), (1, 0, 3), 'right'): 2.635976948155168, ((0, 1, 1, 0, 1, 1), (1, 0, 4), 'down'): -0.2965087303890719, ((0, 1, 1, 0, 1, 1), (1, 0, 4), 'left'): -0.32597227544945734, ((0, 1, 1, 0, 1, 1), (1, 0, 3), 'left'): 8.86835782977489, ((0, 1, 1, 0, 1, 1), (1, 0, 2), 'left'): 8.216717615897295, ((0, 1, 1, 0, 1, 1), (1, 0, 1), 'down'): 7.831463103310142, ((0, 1, 1, 0, 1, 1), (1, 0, 1), 'hold'): 13.763340413562505, ((0, 1, 1, 0, 1, 1), (1, 0, 1), 'right'): 8.612803892121686, ((0, 1, 1, 0, 1, 
1), (1, 0, 1), 'left'): 8.332808310593071, ((0, 1, 1, 0, 1, 1), (1, 0, 0), 'down'): 14.545351854722009, ((0, 1, 1, 0, 1, 1), (1, 0, 0), 'right'): 8.450721219918831, ((1, 1, 1, 1, 0, 0), (1, 0, 1), 'hold'): 72.23889294055095, ((0, 0, 1, 1, 1, 1), (1, 0, 0), 'hold'): 38.95747740830821, ((0, 0, 0, 0, 1, 1), (1, 0, 0), 'hold'): 223.6388796978828, ((0, 1, 1, 0, 1, 1), (1, 0, 2), 'hold'): 12.12793565465541, ((0, 0, 0, 0, 1, 1), (1, 0, 4), 'hold'): 146.0850274313299, ((1, 1, 0, 0, 0, 0), (1, 0, 1), 'hold'): 71.12153515644901, ((1, 1, 0, 0, 1, 1), (1, 0, 2), 'down'): 198.6764127974222, ((1, 1, 0, 0, 1, 1), (1, 0, 2), 'left'): 196.2837012461664, ((1, 1, 0, 0, 1, 1), (1, 0, 1), 'down'): 199.35487057813015, ((1, 1, 0, 0, 1, 1), (1, 0, 1), 'right'): 198.96334202809095, ((1, 1, 0, 0, 1, 1), (1, 0, 2), 'right'): 191.88276486208582, ((1, 1, 0, 0, 1, 1), (1, 0, 3), 'down'): 178.74288855432115, ((1, 1, 0, 0, 1, 1), (1, 0, 3), 'left'): 195.66920154944302, ((1, 1, 0, 0, 1, 1), (1, 0, 2), 'hold'): 828.9535626475777, ((1, 1, 0, 0, 1, 1), (1, 0, 3), 'right'): 91.67857033918713, ((1, 1, 0, 0, 1, 1), (1, 0, 4), 'down'): 55.82297339007018, ((1, 1, 0, 0, 1, 1), (1, 0, 4), 'left'): 187.2691452876171, ((1, 1, 0, 0, 1, 1), (1, 0, 1), 'left'): 178.1497405311065, ((1, 1, 0, 0, 1, 1), (1, 0, 0), 'down'): 155.25651636851538, ((1, 1, 0, 0, 1, 1), (1, 0, 0), 'right'): 243.46606200584588, ((0, 0, 1, 1, 1, 1), (1, 0, 1), 'hold'): 36.465524748807255, ((0, 0, 1, 1, 0, 0), (1, 0, 3), 'hold'): 5.562700257521145, ((0, 0, 0, 1, 1, 0), (1, 0, 3), 'hold'): -0.252482381577117, ((0, 1, 1, 1, 1, 0), (1, 0, 1), 'hold'): 12.768087130669342, ((1, 1, 0, 0, 1, 1), (1, 0, 3), 'hold'): 506.24171365231405, ((0, 1, 1, 1, 1, 0), (1, 0, 0), 'hold'): 6.771755440966897, ((0, 1, 1, 1, 1, 0), (1, 0, 3), 'hold'): 11.8523168135685, ((0, 1, 1, 1, 1, 0), (1, 0, 4), 'hold'): 20.37645693135214, ((1, 1, 0, 1, 1, 0), (1, 0, 1), 'hold'): 35.06477269398806, ((1, 1, 0, 1, 1, 0), (1, 0, 3), 'hold'): 37.74127164678059, ((0, 0, 1, 1, 0, 0), 
(1, 0, 4), 'hold'): 42.695809426618744, ((0, 0, 0, 0, 0, 0), (1, 0, 4), 'hold'): 2.9762236420880726, ((0, 0, 0, 0, 0, 0), (1, 0, 3), 'hold'): -1.045624042911788, ((0, 1, 1, 0, 1, 1), (1, 0, 3), 'hold'): 25.063372551149985, ((0, 1, 1, 0, 1, 1), (1, 0, 4), 'hold'): 22.028066168899908, ((1, 1, 0, 1, 1, 0), (1, 0, 4), 'hold'): 47.820338099345264, ((0, 1, 1, 0, 0, 0), (1, 0, 4), 'hold'): -1.406965491638156, ((1, 1, 0, 0, 1, 1), (1, 0, 1), 'hold'): 1475.879300864999, ((1, 1, 0, 0, 1, 1), (1, 0, 0), 'hold'): 1848.722890447066, ((0, 0, 0, 1, 1, 0), (1, 0, 0), 'hold'): -1.2272740860410951, ((1, 1, 1, 1, 0, 0), (1, 0, 4), 'hold'): 66.33075473904886, ((0, 1, 1, 0, 1, 1), (1, 0, 0), 'hold'): 20.1427448111547, ((1, 1, 0, 0, 1, 1), (1, 0, 4), 'hold'): 1413.3413016609172, ((0, 0, 0, 1, 1, 0), (1, 0, 4), 'hold'): 2.8918262128951353, ((0, 1, 1, 0, 0, 0), (1, 0, 0), 'hold'): -2.2192598543132243, ((1, 1, 0, 0, 0, 0), (1, 0, 4), 'hold'): 212.60599396831736, ((1, 1, 0, 1, 1, 0), (1, 0, 0), 'hold'): 3.793657464709744, ((0, 0, 0, 0, 0, 0), (1, 0, 0), 'hold'): -3.4748124103536724, ((0, 0, 1, 1, 1, 1), (1, 0, 4), 'hold'): 121.83104689424648}

In [None]:
# Run one game with an agent driven by a previously saved Q-table instead of
# one learned in this session.
# NOTE(review): lr is presumably the learning rate and may be unused once
# eval() is called — confirm against QAgent.
q_agent = QAgent(lr = 0.1)
# The table is assigned directly; no copy is made on this line.
# NOTE(review): good_q_table is presumably the dict literal ending just above
# this cell, whose visible entries map (6-tuple, 3-tuple, action string) keys
# to float values — confirm the name at the start of that literal.
q_agent.q_table = good_q_table
# NOTE(review): presumably switches the agent to evaluation behavior (no
# exploration / no Q-table updates) — confirm against QAgent.eval.
q_agent.eval()
# NOTE(review): the meaning of the True flag passed to Tetris is not visible
# here — confirm against Tetris.__init__.
env = Env(q_agent, Tetris(True))
env.execute()