In [1]:
import warnings; warnings.filterwarnings('ignore')  # kept first so import-time warnings are silenced too
import base64
import glob
import os
import random

import gym
from gym import spaces
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim

%matplotlib inline

In [2]:
class DeepSARSA:
    """On-policy SARSA agent whose action values come from a small MLP.

    The network maps a flat state vector to one Q-value per action. Each call
    to ``update`` takes one Adam step on the squared TD error of a single
    (s, a, r, s', a') transition.
    """

    def __init__(self, num_states, num_actions, alpha=0.001, gamma=0.99,
                 epsilon=1., epsilon_decay=.99999, epsilon_min=0.2,
                 hidden_size=32):
        """Build the Q-network and its optimizer.

        Args:
            num_states: length of the flat state vector fed to the network.
            num_actions: number of discrete actions (size of the network output).
            alpha: Adam learning rate.
            gamma: discount factor applied to the bootstrapped next-state value.
            epsilon: initial exploration rate for epsilon-greedy ``act``.
            epsilon_decay: multiplicative decay applied on every ``update``.
            epsilon_min: lower bound for ``epsilon``.
            hidden_size: width of the two hidden layers.
        """
        self.num_states = num_states
        self.num_actions = num_actions
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.model = nn.Sequential(
            nn.Linear(self.num_states, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, self.num_actions),
        )
        # The optimizer adjusts the network weights to shrink the gap between
        # the predicted Q-value and the TD target computed in `update`.
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.alpha)

    def _to_tensor(self, state):
        """Convert a state (NumPy array, list or tensor) into the float32 tensor the network expects."""
        # Observation wrappers typically emit uint8 NumPy arrays, which nn.Linear rejects.
        return torch.as_tensor(state, dtype=torch.float32)

    def update(self, state, action, reward, next_state, next_action, done):
        """Apply one SARSA update and return the squared TD error as a float.

        Args:
            state, next_state: flat state vectors (array-like or tensors).
            action, next_action: integer action indices taken in each state.
            reward: scalar reward received after taking ``action``.
            done: truthy when ``next_state`` is terminal; the bootstrap term is then dropped.
        """
        self.decrease_epsilon()
        self.optimizer.zero_grad()

        q_value = self.model(self._to_tensor(state))[action]
        # The bootstrapped target is treated as a constant: no gradient flows into it.
        with torch.no_grad():
            next_q_value = self.model(self._to_tensor(next_state))[next_action]

        q_target = reward + (1 - int(done)) * self.gamma * next_q_value
        q_error = (q_target - q_value) ** 2

        q_error.backward()
        self.optimizer.step()

        return q_error.item()

    def decrease_epsilon(self):
        """Decay epsilon multiplicatively, never letting it drop below ``epsilon_min``."""
        if self.epsilon > self.epsilon_min:
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def act(self, state):
        """Pick an action epsilon-greedily.

        With probability ``epsilon`` a uniformly random action is returned;
        otherwise the action with the highest predicted Q-value. Returns a Python int.
        """
        if np.random.rand() < self.epsilon:
            return int(np.random.choice(self.num_actions))
        # Action selection needs no autograd graph.
        with torch.no_grad():
            q_values = self.model(self._to_tensor(state))
        return int(torch.argmax(q_values).item())
        

In [5]:
class FlatObsWrapper(gym.core.ObservationWrapper):
    """Fully observable gridworld returning a flat grid encoding.

    Each observation is the full-grid encoding with the agent drawn into its
    cell and the outer wall ring stripped, flattened to a 1-D vector of
    ``(width - 2) * (height - 2) * 3`` values.
    """

    def __init__(self, env, max_steps=None):
        """Wrap ``env`` and configure its episode step limit.

        Args:
            env: environment exposing ``width``/``height`` and, on its
                unwrapped form, ``grid``, ``agent_pos`` and ``agent_dir``.
            max_steps: value written to ``env.unwrapped.max_steps``. ``None``
                (the default) falls back to the notebook-level global
                ``max_env_steps``, matching the previous behaviour. Pass it
                explicitly so the wrapper does not depend on that global.
        """
        super().__init__(env)

        # The outer walls are always present, so they are dropped from the
        # observation. Each cell carries 3 channels; for simplicity the agent
        # works with the flattened grid.
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=((self.env.width - 2) * (self.env.height - 2) * 3,),  # interior cells x 3 channels
            dtype='uint8'
        )
        if max_steps is None:
            # NOTE(review): `max_env_steps` is not defined in any visible cell;
            # define it before constructing the wrapper, or pass `max_steps`.
            max_steps = max_env_steps
        self.unwrapped.max_steps = max_steps

    def observation(self, obs):
        """Return the wall-free, flattened grid encoding with the agent placed in it.

        ``obs`` (the wrapped env's own observation) is ignored; the grid is
        re-encoded from the unwrapped environment instead. Called by ``step()``/``reset()``.
        """
        env = self.unwrapped
        full_grid = env.grid.encode()
        # Overwrite the agent's cell with (object id, colour id, direction).
        # NOTE(review): assumes encode() is indexed [x, y, channel] and agent_pos is (x, y).
        # OBJECT_TO_IDX / COLOR_TO_IDX are presumably MiniGrid's lookup tables and are
        # not imported in any visible cell.
        full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array([
            OBJECT_TO_IDX['agent'],
            COLOR_TO_IDX['red'],
            env.agent_dir
        ])
        full_grid = full_grid[1:-1, 1:-1]  # strip the outer wall ring (smaller input)

        return full_grid.ravel()

    def render(self, *args, **kwargs):
        """Render the unwrapped env without highlighting the partially observable field of view."""
        kwargs['highlight'] = False
        return self.unwrapped.render(*args, **kwargs)

In [6]:
def wrap_env(env, video_dir='./video'):
    """Wrap ``env`` in gym's video-recording ``Monitor``.

    Args:
        env: the environment to record.
        video_dir: directory the Monitor writes recordings into. ``force=True``
            lets it reuse a directory that already holds earlier monitor output.

    Returns:
        The Monitor-wrapped environment.
    """
    # `Monitor` was never imported in this notebook, which raised NameError on a fresh
    # kernel. It is imported here rather than at the top because it only exists in older
    # gym releases (later replaced by RecordVideo); a missing class then fails only this
    # helper, not the whole notebook.
    from gym.wrappers import Monitor

    return Monitor(env, video_dir, force=True)

def gen_wrapped_env(env_name):
    """Create ``env_name`` with flat observations and video recording.

    Wrapping order, innermost first:
    ``gym.make(env_name)`` -> ``FlatObsWrapper`` -> ``wrap_env``.

    NOTE(review): MiniGrid ids are only registered once the MiniGrid package
    has been imported (presumably ``import gym_minigrid``). The UnregisteredEnv
    error at the end of this notebook suggests that import is missing.
    """
    base_env = gym.make(env_name)
    flat_env = FlatObsWrapper(base_env)
    return wrap_env(flat_env)


def show_video(video_dir='video'):
    """Embed the first recorded MP4 in ``video_dir`` as an inline, looping HTML5 video.

    Does nothing (returns None) when the directory holds no ``.mp4`` files.

    Args:
        video_dir: directory to search. The default is the same folder ``wrap_env`` records into.

    NOTE(review): ``display`` (used as ``display.display``, i.e. presumably the
    ``IPython.display`` module) and ``HTML`` are not imported in any visible cell.
    """
    # glob's ordering is arbitrary; sorting makes "first" deterministic.
    mp4_paths = sorted(glob.glob(os.path.join(video_dir, '*.mp4')))
    if not mp4_paths:
        return
    # The context manager closes the file handle as soon as it has been read.
    with open(mp4_paths[0], 'rb') as video_file:
        encoded = base64.b64encode(video_file.read())
    display.display(HTML(data='''<video alt="test" autoplay 
                loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii'))))

In [8]:
# Seed every RNG the agent uses so network initialisation and epsilon-greedy
# choices are reproducible across Restart & Run All (the env itself is not seeded here).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

env = gen_wrapped_env('MiniGrid-Empty-5x5-v0')
obs = env.reset()

# NOTE(review): num_actions=3 presumably limits the agent to MiniGrid's first three
# actions (turn left, turn right, move forward) — confirm against env.action_space.
agent = DeepSARSA(obs.shape[0], 3)

UnregisteredEnv: No registered env with id: MiniGrid-Empty-5x5-v0