In [1]:
import erlyx

In [2]:
from erlyx.environment import GymAtariBWEnvironment

In [3]:
from erlyx.policies import PytorchPolicy
from erlyx.agents import BaseAgent
from erlyx.datasets import SequenceDataset
from erlyx.algorithms.ddqn import DoubleDeepQLearner

In [4]:
from erlyx.callbacks import BaseCallback
from erlyx.callbacks.recorders import TransitionRecorder, RewardRecorder
from erlyx.callbacks.updaters import OnlineUpdater, LinearEpsilonDecay
from erlyx.callbacks.checkpoint import PytorchCheckPointer

In [5]:
import numpy as np
import torch
from datetime import datetime

from collections import deque
from pathlib import Path
import os
from time import sleep
from collections import OrderedDict

## Agent

In [6]:
class EpsilonGreedyHistoryAgent(BaseAgent):
    """Epsilon-greedy agent that acts on a stack of the most recent observations.

    The agent keeps the last ``history_length`` observations in a bounded
    buffer. On every call to :meth:`select_action` the newest observation is
    pushed into the buffer and the stacked buffer is what the policy sees.

    Args:
        policy: policy object forwarded to ``BaseAgent``; it must expose
            ``get_distribution(observation)``.
        epsilon: probability of picking a uniformly random action.
        obs_shape: shape of a single observation, used to build the
            zero-filled frames that pad the history.
        history_length: number of observations kept in the stack
            (defaults to 4, the value that was previously hard-coded).
    """

    def __init__(self, policy, epsilon, obs_shape, history_length=4):
        super(EpsilonGreedyHistoryAgent, self).__init__(policy=policy)
        self.epsilon = epsilon
        self._obs_shape = obs_shape
        self._history_length = history_length
        # Start with a full, zero-padded history. Previously the buffer was
        # created empty, so calling select_action() before reset_memory()
        # produced a stack with fewer than `history_length` frames.
        self.reset_memory()

    def reset_memory(self):
        """Replace the history with ``history_length`` all-zero frames."""
        blank_frames = [
            np.zeros(shape=self._obs_shape) for _ in range(self._history_length)
        ]
        self._memory_buffer = deque(blank_frames, maxlen=self._history_length)

    def select_action(self, observation):
        """Push ``observation`` into the history and choose an action.

        With probability ``epsilon`` a random element of ``self.action_space``
        is returned; otherwise the index of the largest entry of the policy's
        distribution for the stacked history is returned.
        """
        # The history is updated even when the action ends up being random,
        # so the stack always reflects every observation seen.
        self._memory_buffer.append(observation)
        observation = np.asarray(list(self._memory_buffer))
        if np.random.uniform() < self.epsilon:
            # action_space is not defined in this class; presumably provided
            # by BaseAgent — confirm.
            return np.random.choice(self.action_space)
        distribution = self.policy.get_distribution(observation)
        return np.argmax(distribution)

In [7]:
class AgentMemoryReseter(BaseCallback):
    """Callback that clears an agent's observation history at episode start."""

    def __init__(self, agent):
        # Agent whose `reset_memory()` is invoked at the start of each episode.
        self._agent = agent

    def on_episode_begin(self, initial_observation):
        """Reset the wrapped agent's memory; ``initial_observation`` is unused."""
        tracked_agent = self._agent
        tracked_agent.reset_memory()

## Policy

In [8]:
class DuelingNetwork(torch.nn.Module):
    """Convolutional dueling Q-network.

    A shared convolutional trunk feeds two fully connected heads: an
    advantage head with one output per action and a scalar state-value head.
    The heads are combined as ``Q = V + A - mean_a(A)``.

    Args:
        num_actions: number of outputs of the advantage head (defaults to 18,
            the value that was previously hard-coded).

    The linear layers expect ``64 * 49`` flattened features, which is what the
    trunk produces for 84x84 inputs with 4 channels.
    """

    def __init__(self, num_actions=18):
        super(DuelingNetwork, self).__init__()
        # Attribute names are kept unchanged so existing state dicts still load.
        self._conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(4, 32, kernel_size=8, stride=4),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, kernel_size=3, stride=2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, kernel_size=3, stride=1),
            torch.nn.ReLU(),
        )
        self._linear_advantage = torch.nn.Sequential(
            torch.nn.Linear(64*49, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, num_actions)
        )
        self._linear_value = torch.nn.Sequential(
            torch.nn.Linear(64*49, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 1)
        )

    def forward(self, x):
        """Return Q-values of shape ``(batch, num_actions)``.

        A 3-D input is treated as a single unbatched sample and gets a
        leading batch dimension of 1.
        """
        if len(x.shape) == 3:
            x = x.unsqueeze(0)
        features = self._conv1(x)
        features = features.view(x.shape[0], -1)
        advantage = self._linear_advantage(features)
        value = self._linear_value(features)
        # Centre the advantages per sample (over the action axis). The former
        # `advantage.mean()` averaged over the whole batch as well, so every
        # sample's Q-values depended on the other samples in the batch. For a
        # batch of one both forms give the same result.
        return value + advantage - advantage.mean(dim=1, keepdim=True)

In [9]:
class CNNPolicy(PytorchPolicy):
    """Policy backed by a ``DuelingNetwork`` that maps Q-values to a softmax."""

    def __init__(self):
        self._model = DuelingNetwork()

    @property
    def model(self):
        """The underlying ``DuelingNetwork`` instance."""
        return self._model

    def process_observation(self, observation):
        """Scale a raw observation and convert it to a ``torch.Tensor``.

        Computes ``(observation / 128 - 1) / 0.35``.
        """
        centred = observation/128. - 1.
        # divided by 0.35 because of standard deviation
        return torch.Tensor(centred)/0.35

    def get_distribution(self, state):
        """Return the softmax over the model's Q-values as a flat numpy array."""
        self.model.eval()
        with torch.no_grad():
            network_input = self.process_observation(state)
            q_values = self.model(network_input)
            probabilities = torch.nn.functional.softmax(q_values, dim=1)
        return probabilities.data.cpu().numpy().reshape(-1)

    def num_actions(self):
        """Number of discrete actions (fixed at 18)."""
        return 18


# Setup: resume from checkpoint

In [10]:
# Checkpoint file from an earlier run; the cells below read its 'model' and
# 'optimizer' entries.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
checkpoint_file = '20200830213341/latest_checkpoint.sd'
checkpoint = torch.load(checkpoint_file)

In [11]:
# Height/width of a single observation frame.
img_hw = (84, 84)

# Policy: fresh network, weights restored from the checkpoint.
policy = CNNPolicy()
policy.model.load_state_dict(checkpoint['model'])

# Agent: epsilon-greedy over the restored policy.
agent = EpsilonGreedyHistoryAgent(policy=policy, epsilon=0.01, obs_shape=img_hw)

# Dataset: storage dtypes per field.
dataset_maxlen = 1_000_000
dtypes = {
    'state': np.uint8,
    'action': np.uint8,
    'done': np.uint8,
    'reward': np.float32,
}
dataset = SequenceDataset(
    input_shape=img_hw,
    max_length=dataset_maxlen,
    history_length=4,
    dtypes=dtypes,
)

# Learner: optimizer state restored from the checkpoint.
# NOTE(review): the positional arguments (32, 1e-4/4, 30_000) are passed as in
# the original; their meaning is defined by DoubleDeepQLearner — confirm there.
huber_loss = torch.nn.SmoothL1Loss()
learner = DoubleDeepQLearner(32, policy, 1e-4/4, 30_000, loss_func=huber_loss)
learner.optimizer.load_state_dict(checkpoint['optimizer'])

# Load Dataset

In [12]:
import h5py

# Load the persisted arrays into the dataset's internal buffers.
# The `with` block guarantees the HDF5 file handle is closed once the arrays
# have been copied into memory (previously it was left open).
# Indexing with hf[name] raises KeyError for a missing dataset, whereas
# hf.get(name) returned None and np.array(None) silently produced a useless
# object array.
with h5py.File('20200830213341/dataset.h5', 'r') as hf:
    dataset._actions = np.array(hf['actions'])
    dataset._rewards = np.array(hf['rewards'])
    dataset._states = np.array(hf['states'])
    dataset._dones = np.array(hf['dones'])

In [13]:
# Manually set the dataset's internal position after loading the arrays.
# NOTE(review): 8000 is hard-coded; presumably it is the position reported
# when the loaded dataset was saved — confirm before resuming a run.
dataset._position = 8000

# Train

In [14]:
# One folder per run, named after the current timestamp (YYYYmmddHHMMSS).
run_stamp = datetime.now().strftime("%Y%m%d%H%M%S")
log_folder = Path(run_stamp)
# Raises FileExistsError if the folder already exists.
log_folder.mkdir()
print(f'logging to folder {log_folder}')

logging to folder 20200904195957


In [19]:
# Base callbacks: agent memory reset, transition recording, learner updates
# and the epsilon schedule.
# NOTE(review): LinearEpsilonDecay is given 0.01 twice; if those are the start
# and end values, epsilon stays constant — confirm this is intended.
train_callbacks = [
    AgentMemoryReseter(agent=agent),
    TransitionRecorder(dataset),
    OnlineUpdater(learner, dataset, 0, 1),
    LinearEpsilonDecay(agent, 0.01, 0.01, 1e-6),
]

# Logging callbacks.
train_callbacks.append(RewardRecorder(log_folder/'reward_log'))

# The checkpointer gets its own environment instance, created with
# simplified_reward=False.
checkpoint_env = GymAtariBWEnvironment('Seaquest-v0', simplified_reward=False, img_hw=(84,84))
train_callbacks.append(
    PytorchCheckPointer(checkpoint_env, agent, learner, 25_000, 0,
                        log_folder, log_folder/'checkpoint_log')
)

In [20]:
# Training environment (uses the environment's default reward setting).
train_env = GymAtariBWEnvironment('Seaquest-v0', img_hw=(84,84))
# Run for 12,000,000 steps with the callbacks assembled above.
erlyx.run_steps(train_env, agent, 12_000_000, train_callbacks)

HBox(children=(FloatProgress(value=0.0, max=12000000.0), HTML(value='')))

new best wachin!
[Evaluation] mean: 8134.0, rewards: [8390.0, 2280.0, 4720.0, 5340.0, 19940.0]
[Evaluation] mean: 7280.0, rewards: [2590.0, 17370.0, 700.0, 13150.0, 2590.0]
new best wachin!
[Evaluation] mean: 8576.0, rewards: [3090.0, 11430.0, 17900.0, 8060.0, 2400.0]
[Evaluation] mean: 8038.0, rewards: [7770.0, 2850.0, 11070.0, 9650.0, 8850.0]
[Evaluation] mean: 5262.0, rewards: [2060.0, 3200.0, 8880.0, 6970.0, 5200.0]
new best wachin!
[Evaluation] mean: 10910.0, rewards: [8500.0, 16020.0, 560.0, 17450.0, 12020.0]
new best wachin!
[Evaluation] mean: 13658.0, rewards: [14840.0, 14950.0, 14090.0, 15630.0, 8780.0]
[Evaluation] mean: 11010.0, rewards: [5420.0, 27990.0, 6600.0, 7260.0, 7780.0]
[Evaluation] mean: 5924.0, rewards: [3210.0, 2820.0, 5990.0, 8200.0, 9400.0]
[Evaluation] mean: 10276.0, rewards: [7880.0, 7390.0, 7040.0, 21960.0, 7110.0]
[Evaluation] mean: 7358.0, rewards: [6520.0, 3070.0, 13020.0, 6770.0, 7410.0]
[Evaluation] mean: 11314.0, rewards: [18690.0, 6300.0, 9710.0, 1131

[Evaluation] mean: 15212.0, rewards: [24540.0, 28990.0, 5190.0, 3120.0, 14220.0]
[Evaluation] mean: 16148.0, rewards: [36090.0, 13820.0, 8120.0, 7210.0, 15500.0]
new best wachin!
[Evaluation] mean: 27346.0, rewards: [18160.0, 39160.0, 25070.0, 28840.0, 25500.0]
[Evaluation] mean: 8666.0, rewards: [14900.0, 15860.0, 7900.0, 2460.0, 2210.0]
[Evaluation] mean: 8086.0, rewards: [13670.0, 5420.0, 6240.0, 7210.0, 7890.0]
[Evaluation] mean: 12578.0, rewards: [5470.0, 5630.0, 16480.0, 17210.0, 18100.0]
[Evaluation] mean: 9668.0, rewards: [12580.0, 6660.0, 15690.0, 920.0, 12490.0]
[Evaluation] mean: 17772.0, rewards: [7660.0, 15490.0, 30200.0, 24210.0, 11300.0]
[Evaluation] mean: 9112.0, rewards: [6220.0, 12910.0, 4270.0, 5550.0, 16610.0]
[Evaluation] mean: 18944.0, rewards: [5770.0, 28670.0, 23420.0, 16730.0, 20130.0]
[Evaluation] mean: 9760.0, rewards: [9670.0, 15030.0, 12440.0, 6340.0, 5320.0]
[Evaluation] mean: 4054.0, rewards: [2430.0, 1820.0, 6110.0, 6320.0, 3590.0]
[Evaluation] mean: 907

[Evaluation] mean: 10478.0, rewards: [20630.0, 12960.0, 4360.0, 5830.0, 8610.0]
[Evaluation] mean: 13880.0, rewards: [15380.0, 9780.0, 21380.0, 13670.0, 9190.0]
[Evaluation] mean: 11880.0, rewards: [4800.0, 8130.0, 14230.0, 12290.0, 19950.0]
[Evaluation] mean: 12850.0, rewards: [22850.0, 16130.0, 2310.0, 6150.0, 16810.0]
[Evaluation] mean: 6954.0, rewards: [4190.0, 3000.0, 6680.0, 15540.0, 5360.0]
[Evaluation] mean: 9624.0, rewards: [4650.0, 12770.0, 7380.0, 17890.0, 5430.0]
[Evaluation] mean: 8474.0, rewards: [16610.0, 2150.0, 14500.0, 3020.0, 6090.0]
[Evaluation] mean: 9610.0, rewards: [11330.0, 14580.0, 6570.0, 12980.0, 2590.0]
[Evaluation] mean: 11138.0, rewards: [4850.0, 15350.0, 16890.0, 15800.0, 2800.0]
[Evaluation] mean: 9742.0, rewards: [16890.0, 7690.0, 5690.0, 3730.0, 14710.0]
[Evaluation] mean: 16648.0, rewards: [2490.0, 29610.0, 15800.0, 17250.0, 18090.0]
[Evaluation] mean: 14248.0, rewards: [26270.0, 9970.0, 15520.0, 5680.0, 13800.0]
[Evaluation] mean: 13782.0, rewards: [

[Evaluation] mean: 8814.0, rewards: [7680.0, 17350.0, 2240.0, 8690.0, 8110.0]
[Evaluation] mean: 13658.0, rewards: [18240.0, 4750.0, 16220.0, 17690.0, 11390.0]
[Evaluation] mean: 5834.0, rewards: [2390.0, 2000.0, 2230.0, 4570.0, 17980.0]
[Evaluation] mean: 8814.0, rewards: [2900.0, 12250.0, 12950.0, 13190.0, 2780.0]
[Evaluation] mean: 8824.0, rewards: [2150.0, 2920.0, 13180.0, 5660.0, 20210.0]
[Evaluation] mean: 16942.0, rewards: [12880.0, 38360.0, 16360.0, 4610.0, 12500.0]
[Evaluation] mean: 5432.0, rewards: [2770.0, 420.0, 380.0, 7860.0, 15730.0]


KeyboardInterrupt: 

## Persist dataset

In [38]:
import h5py

In [39]:
# Write the dataset's internal buffers to <log_folder>/dataset.h5.
# The `with` block closes the file even if one of the writes raises, so a
# failed save no longer leaves an open handle behind.
with h5py.File(log_folder/'dataset.h5', 'w') as hf:
    hf.create_dataset('states', data=dataset._states, dtype=np.uint8)
    hf.create_dataset('actions', data=dataset._actions, dtype=np.uint8)
    hf.create_dataset('rewards', data=dataset._rewards, dtype=np.float32)
    hf.create_dataset('dones', data=dataset._dones, dtype=np.uint8)
# The position is not stored in the file; it is printed so it can be noted.
print(f'latest dataset position: {dataset._position}')

position = 816565

In [37]:
# Display the dataset's internal `_position` value (it is not written to the
# HDF5 file, so it has to be noted separately).
dataset._position

816565

## Eval

In [35]:
# Run a single evaluation episode with the current agent.
env = GymAtariBWEnvironment('Seaquest-v0', repeat=4, simplified_reward=False, img_hw=(84, 84))
rr = RewardRecorder()
# NOTE(review): `Tracker` is neither defined nor imported anywhere in the
# visible notebook, so this cell raises NameError on a fresh kernel —
# restore its definition/import before re-running.
tr = Tracker()
# The agent's frame history is reset at episode start via AgentMemoryReseter.
erlyx.run_episodes(env, agent, 1, [AgentMemoryReseter(agent), rr, tr])

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


