<a href="https://colab.research.google.com/github/zzmtsvv/RL-with-gym/blob/main/icehockey.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install Gym and its Atari extra; a later cell creates 'IceHockey-ram-v0' via gym.make.
# NOTE(review): versions are unpinned. The notebook later calls env.seed(...) and imports
# gym.wrappers.Monitor — presumably both need an older gym release; confirm and pin.
! pip install gym
! pip install gym[atari]

In [None]:
# Unpack the ROM archive. Assumes ROMS.zip was placed at /content beforehand — TODO confirm.
! unzip /content/ROMS.zip

In [None]:
# Run atari_py's ROM importer on the extracted folder
# (presumably required before Atari environments can be created — TODO confirm).
!python -m atari_py.import_roms /content/ROMS

In [4]:
import gym
from gym.wrappers import TransformObservation
import torch
from torch import nn
from copy import deepcopy
import numpy as np
import random
import torch.nn.functional as F
import matplotlib.pyplot as plt
from collections import deque, namedtuple

In [16]:
# Training hyperparameters shared by the cells below.

# Capacity handed to the Buffering replay buffer.
BUFFER_SIZE = 1000
# Number of transitions sampled per update; updates begin once the buffer holds this many.
BATCH_SIZE = 128
# Discount factor used in the TD target and in the reward-shaping term of the training loop.
GAMMA = 0.99
# The target network is re-synced when the update counter is a multiple of this value.
TARGET_UPDATE = 500
# Seed passed to the environment, numpy, torch and the random module.
SEED = 252

In [6]:
# Create the environment and seed every random number generator used in the notebook.
env = gym.make('IceHockey-ram-v0')
env.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
random.seed(SEED)
# NOTE(review): exploratory actions come from env.action_space.sample() (see eps_greedy);
# the action space is not seeded explicitly here — confirm whether it needs its own seed.

In [7]:
# Divide every observation by 255.0 before it reaches the agent.
# presumably raw observations are byte values in [0, 255], giving inputs in [0, 1] — TODO confirm
env = TransformObservation(env, lambda x: x / 255.0)

In [8]:
# Pick CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Bare expression: echoes the chosen device as the cell output.
device

device(type='cuda')

In [17]:
class BestAgent:
  """Greedy, inference-only policy built from a snapshot of a Q-network."""

  def __init__(self, model):
    """Take a private, frozen copy of `model`.

    Args:
      model: a torch.nn.Module mapping a batch of states to per-action values.

    Raises:
      ValueError: if `model` is None (e.g. no checkpoint was ever saved).
    """
    if model is None:
      raise ValueError("BestAgent needs a trained model, got None")

    # Copy so that further training of the source network cannot change this agent.
    self.model = deepcopy(model)
    self.model.eval()
    for param in self.model.parameters():
      param.requires_grad_(False)

    # Run inference on whatever device the weights live on instead of relying on a
    # module-level `device` variable defined in another cell.
    first_param = next(self.model.parameters(), None)
    self.device = first_param.device if first_param is not None else torch.device("cpu")

  def act(self, state):
    """Return the index (a Python int) of the highest-valued action for one state.

    Args:
      state: a single observation (array-like); it is converted to a float32
        tensor and given a leading batch dimension of 1.
    """
    state = torch.as_tensor(np.asarray(state), dtype=torch.float32, device=self.device)
    with torch.no_grad():
      q_values = self.model(state.unsqueeze(0))
    # np.argmax returns the first maximum on ties; int() turns the numpy scalar into a plain int.
    return int(np.argmax(q_values.cpu().numpy()))

In [18]:
class ReplayMemory:
  """Bounded FIFO store of `Transition` records with uniform random sampling."""

  def __init__(self, capacity):
    """Create an empty memory that keeps at most `capacity` transitions."""
    field_names = ('state', 'action', 'next_state', 'reward')
    self.transition = namedtuple('Transition', field_names)
    # A deque with maxlen drops its oldest entry once the limit is reached.
    self.memory = deque(maxlen=capacity)
    self.capacity = capacity
    # Never read or updated by this class; kept so the attribute set is unchanged.
    self.position = 0

  def __len__(self):
    """Number of transitions currently stored."""
    return len(self.memory)

  def push(self, *args):
    """Pack the positional values into a Transition and append it."""
    record = self.transition(*args)
    self.memory.append(record)

  def sample(self, batch_size):
    """Return `batch_size` stored transitions drawn at random without replacement."""
    stored = self.memory
    return random.sample(stored, batch_size)


class Buffering:
  """Fixed-capacity ring buffer that overwrites its oldest element when full."""

  def __init__(self, capacity):
    """Create an empty buffer holding at most `capacity` elements."""
    self.capacity = capacity
    self.memory = []
    # Index of the most recently written slot; -1 while the buffer is empty.
    self.position = -1

  def __len__(self):
    """Number of elements currently stored."""
    return len(self.memory)

  def add(self, elem):
    """Store `elem`, replacing the oldest element once the buffer is full."""
    new_pos = (self.position + 1) % self.capacity
    if len(self.memory) < self.capacity:
      # Still growing: the next slot is always the end of the list.
      self.memory.append(elem)
    else:
      self.memory[new_pos] = elem
    self.position = new_pos

  def sample(self, batch_size):
    """Return `batch_size` stored elements drawn at random without replacement."""
    return random.sample(self.memory, batch_size)

  def __getitem__(self, key):
    """Return the element `key` steps after the oldest one.

    `key == 0` is the oldest stored element and `key == -1` the newest; keys
    wrap around the number of stored elements.

    Raises:
      IndexError: if the buffer is empty.
    """
    stored = len(self.memory)
    if stored == 0:
      raise IndexError("Buffering is empty")
    # Wrap by the number of stored elements, not by capacity: while the buffer is
    # only partially filled, slots beyond `stored` do not exist yet.
    return self.memory[(self.position + 1 + key) % stored]

In [28]:
class ArseNet(nn.Module):
  """Fully connected network: six ReLU-activated hidden layers and a linear head."""

  def __init__(self, input_dim, output_dim):
    """Build the layer stack.

    Args:
      input_dim: number of features in one input vector.
      output_dim: number of values produced per input vector.
    """
    super().__init__()

    hidden_widths = (64, 64, 48, 48, 32, 32)
    stack = []
    fan_in = input_dim
    # Layers are created in the same order as they appear in the stack.
    for width in hidden_widths:
      stack.append(nn.Linear(fan_in, width))
      stack.append(nn.ReLU())
      fan_in = width
    stack.append(nn.Linear(fan_in, output_dim))

    # Registered under the name `online`, so parameter keys look like `online.0.weight`.
    self.online = nn.Sequential(*stack)

  def forward(self, x):
    """Apply the layer stack to `x` and return the result."""
    return self.online(x)

In [33]:
class Agent:
  """Double-DQN learner: an online Q-network plus a periodically synced target copy."""

  def __init__(self, state_dim, action_dim):
    """Build the online network, a frozen target copy and the Adam optimizer.

    Args:
      state_dim: size of one state vector (input width of the network).
      action_dim: number of actions (output width of the network).
    """
    self.net = ArseNet(state_dim, action_dim).to(device)
    self.target = deepcopy(self.net).to(device)

    # The target network is only ever evaluated, never optimized.
    for p in self.target.parameters():
      p.requires_grad = False

    self.optimizer = torch.optim.Adam(self.net.parameters(), lr=3e-4)

  def update(self, batch):
    """Run one gradient step on a batch of transitions.

    Args:
      batch: iterable of `(state, action, reward, next_state, done)` tuples.
    """
    states, actions, rewards, next_states, dones = zip(*batch)
    states = torch.as_tensor(np.array(states), dtype=torch.float32, device=device)
    # gather() needs int64 indices; cast explicitly rather than rely on the
    # platform's default integer width.
    actions = torch.as_tensor(np.array(actions), dtype=torch.int64, device=device).unsqueeze(1)
    rewards = torch.as_tensor(np.array(rewards), dtype=torch.float32, device=device).unsqueeze(1)
    next_states = torch.as_tensor(np.array(next_states), dtype=torch.float32, device=device)
    # 1.0 for non-terminal transitions, 0.0 for terminal ones; works whether the
    # done flags arrive as bools or as 0/1 numbers.
    not_done = 1.0 - torch.as_tensor(np.array(dones, dtype=np.float32), device=device).unsqueeze(1)

    with torch.no_grad():
      # The online net picks the next action, the target net evaluates it.
      best_next = self.net(next_states).max(1, keepdim=True)[1]
      next_q = self.target(next_states).gather(1, best_next)
      target = rewards + GAMMA * next_q * not_done

    q_current = self.net(states).gather(1, actions)
    self.optimizer.zero_grad()
    loss = F.mse_loss(q_current, target)
    loss.backward()
    self.optimizer.step()

  def act(self, state):
    """Return the index (a Python int) of the online network's best action for one state."""
    state = torch.as_tensor(np.asarray(state), dtype=torch.float32, device=device)
    with torch.no_grad():
      q_values = self.net(state.unsqueeze(0))
    return int(np.argmax(q_values.cpu().numpy()))

  def update_target(self):
    """Copy the online network's weights into the target network.

    The weights are loaded in place, so the target stays frozen
    (`requires_grad=False`) and no new module is allocated on each sync.
    """
    self.target.load_state_dict(self.net.state_dict())

In [34]:
def eps_greedy(env, agent, state, eps):
  """Epsilon-greedy action selection.

  One number is drawn from `random.random()`; when it is below `eps` the action
  comes from `env.action_space.sample()`, otherwise from `agent.act(state)`.
  """
  explore = random.random() < eps
  return env.action_space.sample() if explore else agent.act(state)


def current_result(iteration, rewards, eps, output_period=100):
  """Report running reward statistics and return their mean.

  Args:
    iteration: index of the episode that just finished.
    rewards: non-empty collection of recent episode returns.
    eps: exploration rate to include in the log line.
    output_period: a line is printed whenever `iteration` is a multiple of this.

  Returns:
    The mean of `rewards` as computed by `np.mean`.
  """
  mean_ = np.mean(rewards)
  if iteration % output_period == 0:
    # max/min are only needed for the log line, so they are computed only here.
    max_ = np.max(rewards)
    min_ = np.min(rewards)
    # The previous format string began with '\e', an invalid escape sequence that
    # printed a literal backslash; the line now starts with plain "episode".
    print(f'episode {iteration} eps={eps} mean={mean_} max={max_} min={min_}')
  return mean_

In [35]:
# 128 network inputs and 18 outputs — presumably the observation size and the number
# of actions of the environment created above; TODO confirm against
# env.observation_space / env.action_space.
agent = Agent(128, 18)
# Replay buffer the training loop adds transitions to and samples batches from.
buf = Buffering(BUFFER_SIZE)

# Number of training episodes.
episodes = 1000
# Exploration rate; the training loop multiplies it by eps_coeff after every episode.
eps = 1
eps_coeff = 0.995
# Count of agent.update calls so far; drives the target-network sync.
net_updates = 0

# Returns of the most recent (up to 100) episodes, used for the running mean.
rewards = deque(maxlen=100)
# Best running mean seen so far and a copy of the network taken when it was reached.
best_mean = -1000
best_model = None
# Per-episode return, filled in by the training loop; entries stay None until then.
overall_rewards = [None] * episodes

In [None]:
# Train with epsilon-greedy exploration, one gradient update per environment step
# once the replay buffer holds at least BATCH_SIZE transitions.
for episode in range(episodes):
  state = env.reset()
  done = False
  total_reward = 0

  while not done:
    action = eps_greedy(env, agent, state, eps)
    next_state, reward, done, _ = env.step(action)
    # Report the raw environment reward ...
    total_reward += reward
    # ... but learn from a shaped one: the discounted change in the mean of the
    # observation vector, scaled by 300, is added to the reward.
    reward += 300 * (GAMMA * np.mean(next_state) - np.mean(state))

    buf.add((state, action, reward, next_state, done))
    if len(buf) >= BATCH_SIZE:
      agent.update(buf.sample(BATCH_SIZE))
      net_updates += 1
      # Sync only right after an update. Checked outside this branch, the test
      # `net_updates % TARGET_UPDATE == 0` is true on every step while
      # net_updates is still 0, which re-copied the network during warm-up.
      if net_updates % TARGET_UPDATE == 0:
        agent.update_target()

    state = next_state

  # Decay exploration once per episode and record the raw return.
  eps *= eps_coeff
  rewards.append(total_reward)
  overall_rewards[episode] = total_reward

  # Keep a copy of the network whenever the running mean return improves.
  mean_reward = current_result(episode, rewards, eps)
  if mean_reward > best_mean:
    best_model = deepcopy(agent.net)
    best_mean = mean_reward

In [None]:
# Learning curve: the recorded return of every training episode.
fig, ax = plt.subplots(figsize=(15, 7))
episode_numbers = range(1, episodes + 1)
ax.plot(episode_numbers, overall_rewards)
ax.set_xlabel('Episodes')
ax.set_ylabel('Reward')
plt.show()

In [39]:
# Install the packages used for recording/rendering below; all output is discarded.
# NOTE(review): because stdout and stderr go to /dev/null, a failed install is silent —
# confirm the packages are present if the following cells fail.
!pip install gym pyvirtualdisplay > /dev/null 2>&1
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1

In [40]:
from gym.wrappers import Monitor
import glob
import io
import base64
from IPython.display import HTML
from IPython import display as ipythondisplay

In [41]:
from pyvirtualdisplay import Display
# Start an invisible 1400x900 virtual display
# (presumably so env.render() works on a machine without a screen — TODO confirm).
display = Display(visible=0, size=(1400, 900))
display.start()

<pyvirtualdisplay.display.Display at 0x7f76e27a1f90>

In [42]:
def show_video(video_dir='video'):
  """Embed the first recorded .mp4 found in `video_dir` in the notebook output.

  Args:
    video_dir: directory searched for '*.mp4' files. Defaults to 'video',
      the folder searched by earlier versions of this function.

  Prints "Could not find video" when the directory contains no .mp4 file.
  """
  # glob returns paths in arbitrary order; sort so the same file is picked every run.
  mp4list = sorted(glob.glob(f'{video_dir}/*.mp4'))
  if not mp4list:
    print("Could not find video")
    return

  # Open read-only and let the context manager close the handle right after reading.
  with open(mp4list[0], 'rb') as video_file:
    encoded = base64.b64encode(video_file.read())
  ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
                loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii'))))
    

def wrap_env(env):
  """Return `env` wrapped in a gym `Monitor` that writes to './video'.

  `force=True` is passed through to Monitor — presumably so existing files in
  that directory may be overwritten; TODO confirm.
  """
  return Monitor(env, './video', force=True)

In [None]:
# Play one greedy episode with the best saved network while recording it.
env = wrap_env(env)
total_reward = 0
try:
  ag = BestAgent(best_model)
  state = env.reset()
  done = False

  while not done:
    env.render()
    action = ag.act(state)
    state, reward, done, _ = env.step(action)
    total_reward += reward
finally:
  # Close even if the rollout raises, so the wrapped environment is always released
  # (presumably this is also what finalizes the recording — TODO confirm).
  env.close()
print(total_reward)
show_video()

In [None]:
# Close the environment.
# NOTE(review): the evaluation cell above already calls env.close(); confirm this
# repeated call is intentional.
env.close()