In [1]:
import gym

import torch
import torch.nn as nn
import torch.nn.functional as F

import random

from tqdm.notebook import tqdm

In [2]:
# globals
# Observation shape; presumably (height, width, RGB channels) of one
# Asteroids frame -- TODO confirm against env.observation_space.
STATE_SHAPE = (210, 160, 3)
# Number of discrete actions, i.e. how many Q-values the network emits per state.
ACTION_SIZE = 14
# Not referenced by any visible cell; presumably the training minibatch
# size -- TODO confirm or remove.
BATCH_SIZE = 32

In [3]:
# Q-Value Network
class QNetwork(nn.Module):
    """Convolutional Q-value estimator.

    A stack of four 3x3 convolutions with two stride-2 max-pools (``net1``)
    feeds a small fully connected head (``net2``) that emits one Q-value per
    action.

    Note:
        The first entry of ``state_size`` is used as the number of input
        channels of the first convolution. With ``state_size == (210, 160, 3)``
        the network therefore expects input of shape ``(batch, 210, 160, 3)``,
        i.e. a raw observation with only a batch axis prepended.
    """

    def __init__(self, state_size, action_size):
        """Build the layers.

        Args:
            state_size (tuple): 3-tuple; only its first entry is used
                (as ``in_channels`` of the first convolution).
            action_size (int): number of actions, i.e. width of the output layer.
        """
        super().__init__()
        hidden_size = 8
        # Unpacking keeps the "must be a 3-tuple" check; only the first entry is used.
        in_channels, _, _ = state_size
        # Flattened size of net1's output for a (batch, 210, 160, 3) input:
        # 128 channels x 40 x 1 spatial positions after the two stride-2 pools.
        flat_features = 128 * 40

        self.net1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, (3, 3), padding=(1, 1)),
            # Was nn.Conv1d with a 2-D kernel. Conv2d has the same weight shape
            # (64, 64, 3, 3), so state dicts stay compatible, and it accepts
            # the 4-D activations this stack actually produces.
            nn.Conv2d(64, 64, (3, 3), padding=(1, 1)),
            nn.MaxPool2d((3, 3), stride=(2, 2), padding=(1, 1), dilation=(1, 1)),
            nn.Conv2d(64, 128, (3, 3), padding=(1, 1)),
            nn.Conv2d(128, 128, (3, 3), padding=(1, 1)),
            nn.MaxPool2d((3, 3), stride=(2, 2), padding=(1, 1), dilation=(1, 1)),
        )

        self.net2 = nn.Sequential(
            nn.Linear(flat_features, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            # Honour the constructor argument instead of a notebook global.
            nn.Linear(hidden_size, action_size),
        )

    def forward(self, x):
        """Estimate Q-values for a batch of states.

        Args:
            x (tensor): batch of states, size (batch x *state_size)

        Returns:
            q-values (tensor): estimated q-values, size (batch x action_size)
        """
        features = self.net1(x)
        # Flatten per sample. Keeping the batch axis explicit means a
        # wrong-sized input fails in the first Linear layer instead of being
        # silently folded into the batch dimension.
        flat = features.view(features.size(0), -1)
        return self.net2(flat)

In [7]:
# Checkpoint holding the whole pickled model (not just a state dict).
# torch.load unpickles the file, so only load checkpoints you trust.
PATH = "models/asteroids-model.pt"

# Map every tensor onto the CPU so the checkpoint loads without a GPU,
# and switch the module to evaluation mode (eval() returns the module itself).
q_network = torch.load(PATH, map_location=torch.device('cpu')).eval()

# Last expression of the cell: show the loaded architecture.
q_network

QNetwork(
  (net1): Sequential(
    (0): Conv2d(210, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Conv1d(64, 64, kernel_size=(3, 3), stride=(1,), padding=(1, 1))
    (2): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), dilation=(1, 1), ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), dilation=(1, 1), ceil_mode=False)
  )
  (net2): Sequential(
    (0): Linear(in_features=5120, out_features=8, bias=True)
    (1): ReLU()
    (2): Linear(in_features=8, out_features=8, bias=True)
    (3): ReLU()
    (4): Linear(in_features=8, out_features=8, bias=True)
    (5): ReLU()
    (6): Linear(in_features=8, out_features=14, bias=True)
  )
)

In [5]:
class AsteroidsAgent(object):
    """Greedy agent: always plays the action with the highest estimated Q-value."""

    def __init__(self, action_space, q_net=None):
        """Create the agent.

        Args:
            action_space: the environment's action space (stored, not otherwise used).
            q_net (callable, optional): maps a batched state tensor to Q-values.
                When omitted, the notebook-level ``q_network`` is looked up at
                act() time, which is the original behaviour.
        """
        self.action_space = action_space
        self.q_net = q_net

    def act(self, observation, reward, done):
        """Pick the greedy action for one observation.

        Args:
            observation: a single state (array-like); a batch axis is prepended
                before it is passed to the network.
            reward: unused; kept for interface compatibility.
            done: unused; kept for interface compatibility.

        Returns:
            int: index of the action with the largest Q-value.
        """
        net = q_network if self.q_net is None else self.q_net
        # Pure inference: skip autograd bookkeeping so no graph is built per step.
        with torch.no_grad():
            x = torch.as_tensor(observation, dtype=torch.float32).unsqueeze(0)
            qs = net(x)
        _, best = qs.squeeze().max(0)
        return best.item()

In [6]:
# Roll out the greedy agent in the Asteroids environment, rendering every frame.
env = gym.make('Asteroids-v0')
env.seed(0)  # fixed seed so the rollout is repeatable
agent = AsteroidsAgent(env.action_space)

episode_count = 1

try:
    for episode in range(episode_count):
        # Reset the per-episode signals so a finished episode does not leak
        # done=True / its last reward into the first step of the next one.
        reward = 0
        done = False
        ob = env.reset()
        while not done:
            action = agent.act(ob, reward, done)
            env.render()
            ob, reward, done, _ = env.step(action)
finally:
    # Always release the emulator / render window, even if a step raises.
    env.close()