In [1]:
import os
import numpy as np
import mxnet as mx
from model.simple_stack import SimpleStack
from utils import check_dir
from memory import Memory
from environments.SimpleEnv import SimpleEnv
from utils import create_input, translate_state
from evaluation import evaluate
from mxnet import gluon, nd, autograd

In [2]:
# training cases
order = "MXNET"
# batch size
batch_size = 512
# agent view
agent_view = 5
map_size = 20
# action max
action_max = 3
# learning rate
model_save = "./model_save/"
lr = 0.001
num_episode = 1000000
# start play
replay_start = 10000
# update step
update_step = 1000
# gamma in q-loss calculation
gamma = 0.99
# memory pool size
memory_length = 100000
# file to save train log
summary = "./{}_Reward.csv".format(order)
eval_statistics = "./{}_CSV.csv".format(order)
# the number of step it take to linearly anneal the epsilon to it min value
annealing_end = 200000
# min level of stochastically of policy (epsilon)-greedy
epsilon_min = 0.2
# temporary files
temporary_model = "./{}/{}.params".format(model_save, order)
temporary_pool = "./{}/{}.pool".format(model_save, order)

In [3]:
if os.path.exists(summary):
    os.remove(summary)
ctx = mx.gpu()
for i in ["model_save", "data_save"]:
    check_dir(i)

In [4]:
# build models
online_model = SimpleStack()
offline_model = SimpleStack()
online_model.collect_params().initialize(mx.init.MSRAPrelu(), ctx=ctx)
offline_model.collect_params().initialize(mx.init.MSRAPrelu(), ctx=ctx)
offline_model.collect_params().zero_grad()



In [5]:
# create env
env = SimpleEnv(display=False, agent_view=agent_view, map_size=map_size)
env.reset_env()
memory_pool = Memory(memory_length)
annealing = 0
total_reward = np.zeros(num_episode)
eval_result = []
loss_func = gluon.loss.L2Loss()
trainer = gluon.Trainer(offline_model.collect_params(), 'adam', {'learning_rate': lr})

In [6]:
_epoch = 0

In [None]:
_print = True
_last_dr_50 = 0
for epoch in range(num_episode):
    env.reset_env()
    finish = 0
    cum_clipped_dr = 0
    if epoch == 51:
        print("Model Structure: ")
        print(offline_model)
    if sum(env.step_count) > replay_start and _print:
        print('annealing and learning are started')
        _print = False
    while not finish:
        if sum(env.step_count) > replay_start:
            annealing += 1
        eps = np.maximum(1 - sum(env.step_count) / annealing_end, epsilon_min)
        if np.random.random() < eps:
            by = "Random"
            action = np.random.randint(0, action_max)
        else:
            by = "Model"
            data = create_input([translate_state(env.map.state())])
            data = [nd.array(i, ctx=ctx) for i in data]
            action = offline_model(data)
            action = int(nd.argmax(action, axis=1).asnumpy()[0])
        old, new, reward_get, finish = env.step(action)
        memory_pool.add(old, new, action, reward_get, finish)
        if finish and epoch > 50:
            cum_clipped_dr += env.detect_rate[-1]
            dr_50 = float(np.mean(env.detect_rate[-50:]))
            dr_all = float(np.mean(env.detect_rate))
            if epoch % 50 == 0:
                text = "DR: %f(50), %f(all), eps: %f" % (dr_50, dr_all, eps)
                print(text)
                with open(summary, "a") as f:
                    f.writelines(text + "\n")
            if epoch % 100 == 0 and annealing > replay_start:
                eval_result.extend(evaluate(ctx, offline_model, env, 5))
            # save model and replace online model each update_step
            if annealing > replay_start and annealing % update_step == 0:
                if dr_50 >= _last_dr_50:
                    _last_dr_50 = dr_50
                    offline_model.save_parameters(temporary_model)
                    online_model.load_parameters(temporary_model, ctx)
    #  train every 2 epoch
    if annealing > replay_start and epoch % 2 == 0:
        # Sample random mini batch of transitions
        if len(memory_pool.memory) > batch_size:
            bz = batch_size
        else:
            bz = len(memory_pool.memory)
        for_train = memory_pool.next_batch(bz)
        state =  [nd.array(i, ctx=ctx) for i in for_train["state"]]
        state_next = [nd.array(i, ctx=ctx) for i in for_train["state_next"]]
        finish = nd.array(for_train["finish"], ctx=ctx)
        action = nd.array(for_train["action"], ctx=ctx)
        reward = nd.array(for_train["reward"], ctx=ctx)
        with autograd.record(train_mode=True):
            q_sp = nd.max(online_model(state_next), axis=1)
            q_sp = q_sp * (nd.ones(bz, ctx=ctx) - finish)
            q_s = offline_model(state)
            q_s = nd.pick(q_s, action, 1)
            loss = nd.mean(loss_func(q_s, (reward + gamma * q_sp)))
        loss.backward()
        trainer.step(bz)
    total_reward[int(epoch) - 1] = cum_clipped_dr

[02:59:07] ../src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (set the environment variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)


Model Structure: 
SimpleStack(
  (view): Sequential(
    (0): Conv2D(2 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): Conv2D(256 -> 128, kernel_size=(2, 2), stride=(1, 1), bias=False)
    (2): Conv2D(128 -> 128, kernel_size=(2, 2), stride=(1, 1), bias=False)
  )
  (map): Sequential(
    (0): Conv2D(3 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (2): Conv2D(256 -> 128, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (3): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (4): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (5): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), bias=False)
  )
  (decision_making): Sequential(
    (0): Dense(13953 -> 1024, Activation(sigmoid))
    (1): Dense(1024 -> 64, Activation(sigmoid))
    (2): Dense(64 -> 3, Activation(sigmoid))
  )
)
DR: 0.032650(50), 0.039723(all), eps: 0.964705
annealing and lea

DR: 0.114496(50), 0.075466(all), eps: 0.200000
DR: 0.073911(50), 0.075444(all), eps: 0.200000
DR: 0.115018(50), 0.075683(all), eps: 0.200000
DR: 0.091049(50), 0.075774(all), eps: 0.200000
DR: 0.097948(50), 0.075906(all), eps: 0.200000
DR: 0.049762(50), 0.075787(all), eps: 0.200000
DR: 0.123107(50), 0.076066(all), eps: 0.200000
DR: 0.094043(50), 0.076156(all), eps: 0.200000
DR: 0.064894(50), 0.076090(all), eps: 0.200000
DR: 0.061322(50), 0.075970(all), eps: 0.200000
DR: 0.061512(50), 0.075887(all), eps: 0.200000
DR: 0.120834(50), 0.076129(all), eps: 0.200000
DR: 0.096109(50), 0.076242(all), eps: 0.200000
DR: 0.086443(50), 0.076265(all), eps: 0.200000
DR: 0.088853(50), 0.076335(all), eps: 0.200000
DR: 0.099205(50), 0.076467(all), eps: 0.200000
DR: 0.104272(50), 0.076621(all), eps: 0.200000
DR: 0.093378(50), 0.076678(all), eps: 0.200000
DR: 0.100161(50), 0.076806(all), eps: 0.200000
DR: 0.104279(50), 0.076947(all), eps: 0.200000
DR: 0.093319(50), 0.077036(all), eps: 0.200000
DR: 0.072849(

DR: 0.135758(50), 0.096346(all), eps: 0.200000
DR: 0.140449(50), 0.096472(all), eps: 0.200000
DR: 0.059334(50), 0.096381(all), eps: 0.200000
DR: 0.080654(50), 0.096336(all), eps: 0.200000
DR: 0.112966(50), 0.096356(all), eps: 0.200000
DR: 0.073437(50), 0.096291(all), eps: 0.200000
DR: 0.035534(50), 0.096092(all), eps: 0.200000
DR: 0.052766(50), 0.095970(all), eps: 0.200000
DR: 0.055642(50), 0.095924(all), eps: 0.200000
DR: 0.055698(50), 0.095812(all), eps: 0.200000
DR: 0.050878(50), 0.095659(all), eps: 0.200000
DR: 0.084088(50), 0.095627(all), eps: 0.200000
DR: 0.074430(50), 0.095610(all), eps: 0.200000
DR: 0.100750(50), 0.095624(all), eps: 0.200000
DR: 0.096381(50), 0.095626(all), eps: 0.200000
DR: 0.054473(50), 0.095512(all), eps: 0.200000
DR: 0.181898(50), 0.095723(all), eps: 0.200000
DR: 0.103135(50), 0.095744(all), eps: 0.200000
DR: 0.061190(50), 0.095675(all), eps: 0.200000
DR: 0.093265(50), 0.095669(all), eps: 0.200000
DR: 0.103398(50), 0.095674(all), eps: 0.200000
DR: 0.095395(

DR: 0.102465(50), 0.093426(all), eps: 0.200000
DR: 0.084986(50), 0.093404(all), eps: 0.200000
DR: 0.069289(50), 0.093359(all), eps: 0.200000
DR: 0.053524(50), 0.093299(all), eps: 0.200000
DR: 0.076256(50), 0.093267(all), eps: 0.200000
DR: 0.111092(50), 0.093301(all), eps: 0.200000
DR: 0.087685(50), 0.093291(all), eps: 0.200000
DR: 0.108932(50), 0.093305(all), eps: 0.200000
DR: 0.111351(50), 0.093338(all), eps: 0.200000
DR: 0.076363(50), 0.093295(all), eps: 0.200000
DR: 0.093255(50), 0.093295(all), eps: 0.200000
DR: 0.074178(50), 0.093242(all), eps: 0.200000
DR: 0.096797(50), 0.093249(all), eps: 0.200000
DR: 0.047187(50), 0.093163(all), eps: 0.200000
DR: 0.062493(50), 0.093107(all), eps: 0.200000
DR: 0.078529(50), 0.093114(all), eps: 0.200000
DR: 0.058937(50), 0.093051(all), eps: 0.200000
DR: 0.095939(50), 0.093053(all), eps: 0.200000
DR: 0.084739(50), 0.093038(all), eps: 0.200000
DR: 0.087961(50), 0.093014(all), eps: 0.200000
DR: 0.090867(50), 0.093010(all), eps: 0.200000
DR: 0.073495(

DR: 0.099165(50), 0.092367(all), eps: 0.200000
DR: 0.116600(50), 0.092401(all), eps: 0.200000
DR: 0.088815(50), 0.092397(all), eps: 0.200000
DR: 0.062896(50), 0.092356(all), eps: 0.200000
DR: 0.103598(50), 0.092371(all), eps: 0.200000
DR: 0.106595(50), 0.092391(all), eps: 0.200000
DR: 0.102047(50), 0.092427(all), eps: 0.200000
