In [1]:
import os
import numpy as np
import mxnet as mx
from utils import check_dir
from memory import Memory
from utils import create_input, translate_state
from evaluation_mxnet import evaluate
from mxnet import gluon, nd, autograd
from environments.SimpleEnv import SimpleEnv

In [2]:
# training cases
order = "MXNET_view_only"
# batch size
batch_size = 512
# agent view
agent_view = 5
map_size = 20
# action max
action_max = 3
# learning rate
model_save = "./model_save/"
lr = 0.001
num_episode = 1000000
# start play
replay_start = 10000
# update step
update_step = 1000
# gamma in q-loss calculation
gamma = 0.99
# memory pool size
memory_length = 100000
# file to save train log
summary = "./{}_Reward.csv".format(order)
eval_statistics = "./{}_CSV.csv".format(order)
# the number of step it take to linearly anneal the epsilon to it min value
annealing_end = 200000
# min level of stochastically of policy (epsilon)-greedy
epsilon_min = 0.2
# temporary files
temporary_model = "./{}/{}.params".format(model_save, order)
temporary_pool = "./{}/{}.pool".format(model_save, order)

In [3]:
if os.path.exists(summary):
    os.remove(summary)
ctx = mx.gpu()
for i in ["model_save", "data_save"]:
    check_dir(i)

In [4]:
# build models
from model.simple_stack import SimpleStack
online_model = SimpleStack()
offline_model = SimpleStack()
online_model.collect_params().initialize(mx.init.MSRAPrelu(), ctx=ctx)
offline_model.collect_params().initialize(mx.init.MSRAPrelu(), ctx=ctx)
offline_model.collect_params().zero_grad()



In [5]:
# create env
env = SimpleEnv(display=False, agent_view=agent_view, map_size=map_size)
env.reset_env()
memory_pool = Memory(memory_length)
annealing = 0
total_reward = np.zeros(num_episode)
eval_result = []
loss_func = gluon.loss.L2Loss()
trainer = gluon.Trainer(offline_model.collect_params(), 'adam', {'learning_rate': lr})

In [None]:
_print = True
best = 0
_all = 0
_update = 0
for epoch in range(num_episode):
    env.reset_env()
    finish = 0
    cum_clipped_dr = 0
    if epoch == 51:
        print("Model Structure: ")
        print(offline_model)
    if sum(env.step_count) > replay_start and _print:
        print('annealing and learning are started')
        _print = False
    while not finish:
        _all += 1
        if sum(env.step_count) > replay_start:
            annealing += 1
        eps = np.maximum(1 - sum(env.step_count) / annealing_end, epsilon_min)
        if np.random.random() < eps:
            by = "Random"
            action = np.random.randint(0, action_max)
        else:
            by = "Model"
            data = create_input([translate_state(env.map.state())])
            data = [nd.array(i, ctx=ctx) for i in data]
            action = offline_model(data)
            action = int(nd.argmax(action, axis=1).asnumpy()[0])
        old, new, reward_get, finish = env.step(action)
        memory_pool.add(old, new, action, reward_get, finish)
        if finish and epoch > 50:
            cum_clipped_dr += env.detect_rate[-1]
            dr_50 = float(np.mean(env.detect_rate[-50:]))
            dr_all = float(np.mean(env.detect_rate))
            if epoch % 50 == 0:
                text = "DR: %f(50), %f(all), eps: %f" % (dr_50, dr_all, eps)
                print(text)
                with open(summary, "a") as f:
                    f.writelines(text + "\n")
            if epoch % 100 == 0 and annealing > replay_start:
                eval_result.extend(evaluate(ctx, offline_model, env, 5))
            # save model and replace online model each update_step
            if annealing > replay_start and annealing % update_step == 0:
                
                offline_model.save_parameters(temporary_model)
                online_model.load_parameters(temporary_model, ctx)
                if best < dr_all:
                    best = dr_all
                    offline_model.save_parameters("./best.params")
    #  train every 2 epoch
    if annealing > replay_start and epoch % 2 == 0:
        _update += 1
        # Sample random mini batch of transitions
        if len(memory_pool.memory) > batch_size:
            bz = batch_size
        else:
            bz = len(memory_pool.memory)
        for_train = memory_pool.next_batch(bz)
        with autograd.record(train_mode=True):
            _state =[nd.array(i, ctx=ctx) for i in for_train["state"]]
            _state_next = [nd.array(i, ctx=ctx) for i in for_train["state_next"]]
            _finish = nd.array(for_train["finish"], ctx=ctx)
            _action = nd.array(for_train["action"], ctx=ctx)
            _reward = nd.array(for_train["reward"], ctx=ctx)
            q_sp = nd.max(online_model(_state_next), axis=1)
            q_sp = q_sp * (nd.ones(bz, ctx=ctx) - _finish)
            q_s_array = offline_model(_state)
            q_s = nd.pick(q_s_array, _action, 1)
            loss = nd.mean(loss_func(q_s, (_reward + gamma * q_sp)))
        loss.backward()
        trainer.step(bz)
    total_reward[int(epoch) - 1] = cum_clipped_dr

[02:10:17] ../src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (set the environment variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)


Model Structure: 
SimpleStack(
  (view): Sequential(
    (0): Conv2D(1 -> 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): Conv2D(256 -> 128, kernel_size=(2, 2), stride=(1, 1), bias=False)
    (2): Conv2D(128 -> 128, kernel_size=(2, 2), stride=(1, 1), bias=False)
  )
  (decision_making): Sequential(
    (0): Dense(1153 -> 1024, Activation(sigmoid))
    (1): Dense(1024 -> 64, Activation(sigmoid))
    (2): Dense(64 -> 3, Activation(sigmoid))
  )
)
DR: 0.065130(50), 0.058388(all), eps: 0.950000
annealing and learning are started
DR: 0.059505(50), 0.058758(all), eps: 0.925000
DR: 0.071570(50), 0.061945(all), eps: 0.900000
DR: 0.094030(50), 0.068337(all), eps: 0.875000
DR: 0.089979(50), 0.071932(all), eps: 0.850000
DR: 0.082802(50), 0.072799(all), eps: 0.822500
DR: 0.108179(50), 0.077156(all), eps: 0.797500
DR: 0.073254(50), 0.075896(all), eps: 0.770000
DR: 0.081338(50), 0.076429(all), eps: 0.745000
DR: 0.104324(50), 0.078464(all), eps: 0.717500
DR: 0.107616(50), 0.080830(all), 

DR: 0.783434(50), 0.409068(all), eps: 0.200000
DR: 0.705363(50), 0.411059(all), eps: 0.200000
DR: 0.762520(50), 0.413054(all), eps: 0.200000
DR: 0.735813(50), 0.415207(all), eps: 0.200000
DR: 0.678072(50), 0.416682(all), eps: 0.200000
DR: 0.718460(50), 0.418690(all), eps: 0.200000
DR: 0.650929(50), 0.419978(all), eps: 0.200000
DR: 0.751345(50), 0.422014(all), eps: 0.200000
DR: 0.582725(50), 0.422895(all), eps: 0.200000
DR: 0.517519(50), 0.423415(all), eps: 0.200000
DR: 0.599048(50), 0.424367(all), eps: 0.200000
DR: 0.823523(50), 0.426625(all), eps: 0.200000
DR: 0.671213(50), 0.427936(all), eps: 0.200000
DR: 0.640748(50), 0.429336(all), eps: 0.200000
DR: 0.655605(50), 0.430535(all), eps: 0.200000
DR: 0.700110(50), 0.432255(all), eps: 0.200000
DR: 0.676871(50), 0.433537(all), eps: 0.200000
DR: 0.808684(50), 0.435763(all), eps: 0.200000
DR: 0.716969(50), 0.437221(all), eps: 0.200000
DR: 0.700738(50), 0.438745(all), eps: 0.200000
DR: 0.723085(50), 0.440203(all), eps: 0.200000
DR: 0.776349(

DR: 0.516919(50), 0.537913(all), eps: 0.200000
DR: 0.711185(50), 0.538396(all), eps: 0.200000
DR: 0.560779(50), 0.538517(all), eps: 0.200000
DR: 0.697623(50), 0.538958(all), eps: 0.200000
DR: 0.671640(50), 0.539405(all), eps: 0.200000
DR: 0.635633(50), 0.539670(all), eps: 0.200000
DR: 0.747393(50), 0.540312(all), eps: 0.200000
DR: 0.649137(50), 0.540610(all), eps: 0.200000
DR: 0.635736(50), 0.540886(all), eps: 0.200000
DR: 0.684361(50), 0.541276(all), eps: 0.200000
DR: 0.664185(50), 0.541680(all), eps: 0.200000
DR: 0.662233(50), 0.542007(all), eps: 0.200000
DR: 0.639060(50), 0.542392(all), eps: 0.200000
DR: 0.617350(50), 0.542594(all), eps: 0.200000
DR: 0.719687(50), 0.543192(all), eps: 0.200000
DR: 0.825776(50), 0.543949(all), eps: 0.200000
DR: 0.728197(50), 0.544563(all), eps: 0.200000
DR: 0.705541(50), 0.544991(all), eps: 0.200000
DR: 0.781050(50), 0.545712(all), eps: 0.200000
DR: 0.749723(50), 0.546252(all), eps: 0.200000
DR: 0.665938(50), 0.546673(all), eps: 0.200000
DR: 0.612073(

DR: 0.734481(50), 0.604963(all), eps: 0.200000
DR: 0.780160(50), 0.605358(all), eps: 0.200000
DR: 0.770293(50), 0.605662(all), eps: 0.200000
DR: 0.713029(50), 0.605923(all), eps: 0.200000
DR: 0.679190(50), 0.606057(all), eps: 0.200000
DR: 0.698195(50), 0.606297(all), eps: 0.200000
DR: 0.757001(50), 0.606573(all), eps: 0.200000
DR: 0.820871(50), 0.607032(all), eps: 0.200000
DR: 0.733196(50), 0.607261(all), eps: 0.200000
DR: 0.792287(50), 0.607632(all), eps: 0.200000
DR: 0.663449(50), 0.607733(all), eps: 0.200000
DR: 0.692117(50), 0.607957(all), eps: 0.200000
DR: 0.794420(50), 0.608293(all), eps: 0.200000
DR: 0.736937(50), 0.608595(all), eps: 0.200000
DR: 0.665329(50), 0.608697(all), eps: 0.200000
DR: 0.721139(50), 0.608969(all), eps: 0.200000
DR: 0.693564(50), 0.609121(all), eps: 0.200000
DR: 0.788745(50), 0.609512(all), eps: 0.200000
DR: 0.732173(50), 0.609731(all), eps: 0.200000
DR: 0.766985(50), 0.610080(all), eps: 0.200000
DR: 0.761789(50), 0.610350(all), eps: 0.200000
DR: 0.733507(

DR: 0.738209(50), 0.635090(all), eps: 0.200000
DR: 0.803240(50), 0.635322(all), eps: 0.200000
DR: 0.757343(50), 0.635539(all), eps: 0.200000
DR: 0.751840(50), 0.635699(all), eps: 0.200000
DR: 0.813481(50), 0.635993(all), eps: 0.200000
DR: 0.710937(50), 0.636095(all), eps: 0.200000
DR: 0.765827(50), 0.636321(all), eps: 0.200000
DR: 0.789039(50), 0.636529(all), eps: 0.200000
DR: 0.826254(50), 0.636837(all), eps: 0.200000
DR: 0.807833(50), 0.637070(all), eps: 0.200000
DR: 0.820006(50), 0.637368(all), eps: 0.200000
DR: 0.794334(50), 0.637581(all), eps: 0.200000
DR: 0.810900(50), 0.637865(all), eps: 0.200000
DR: 0.848325(50), 0.638150(all), eps: 0.200000
DR: 0.865289(50), 0.638506(all), eps: 0.200000
DR: 0.799128(50), 0.638722(all), eps: 0.200000
DR: 0.846859(50), 0.639048(all), eps: 0.200000
DR: 0.758765(50), 0.639209(all), eps: 0.200000
DR: 0.724139(50), 0.639372(all), eps: 0.200000
DR: 0.719787(50), 0.639480(all), eps: 0.200000
DR: 0.791578(50), 0.639726(all), eps: 0.200000
DR: 0.730507(

DR: 0.709957(50), 0.656912(all), eps: 0.200000
DR: 0.642256(50), 0.656933(all), eps: 0.200000
DR: 0.745959(50), 0.657031(all), eps: 0.200000
DR: 0.650673(50), 0.657040(all), eps: 0.200000
DR: 0.565756(50), 0.656940(all), eps: 0.200000
DR: 0.613770(50), 0.656899(all), eps: 0.200000
DR: 0.691597(50), 0.656937(all), eps: 0.200000
DR: 0.655865(50), 0.656908(all), eps: 0.200000
DR: 0.599607(50), 0.656845(all), eps: 0.200000
DR: 0.761308(50), 0.656977(all), eps: 0.200000
DR: 0.659761(50), 0.656980(all), eps: 0.200000
DR: 0.707587(50), 0.657072(all), eps: 0.200000
DR: 0.653154(50), 0.657068(all), eps: 0.200000
DR: 0.654242(50), 0.657072(all), eps: 0.200000
DR: 0.656244(50), 0.657071(all), eps: 0.200000
DR: 0.658937(50), 0.657096(all), eps: 0.200000
DR: 0.751162(50), 0.657197(all), eps: 0.200000
DR: 0.692346(50), 0.657242(all), eps: 0.200000
DR: 0.707807(50), 0.657297(all), eps: 0.200000
DR: 0.721046(50), 0.657385(all), eps: 0.200000
DR: 0.701244(50), 0.657432(all), eps: 0.200000
DR: 0.707732(

DR: 0.650688(50), 0.668105(all), eps: 0.200000
DR: 0.705519(50), 0.668140(all), eps: 0.200000
DR: 0.624078(50), 0.668124(all), eps: 0.200000
DR: 0.610934(50), 0.668072(all), eps: 0.200000
DR: 0.598264(50), 0.668031(all), eps: 0.200000
DR: 0.631708(50), 0.667998(all), eps: 0.200000
DR: 0.566144(50), 0.667928(all), eps: 0.200000
DR: 0.647911(50), 0.667910(all), eps: 0.200000
DR: 0.665813(50), 0.667886(all), eps: 0.200000
DR: 0.644503(50), 0.667864(all), eps: 0.200000
DR: 0.665922(50), 0.667875(all), eps: 0.200000
DR: 0.657846(50), 0.667866(all), eps: 0.200000
DR: 0.671431(50), 0.667872(all), eps: 0.200000
DR: 0.635804(50), 0.667843(all), eps: 0.200000
DR: 0.596751(50), 0.667767(all), eps: 0.200000
DR: 0.533228(50), 0.667646(all), eps: 0.200000
DR: 0.744461(50), 0.667725(all), eps: 0.200000
DR: 0.666311(50), 0.667724(all), eps: 0.200000
DR: 0.633685(50), 0.667696(all), eps: 0.200000
DR: 0.610135(50), 0.667644(all), eps: 0.200000
DR: 0.690449(50), 0.667679(all), eps: 0.200000
DR: 0.619312(

DR: 0.773189(50), 0.666343(all), eps: 0.200000
DR: 0.612216(50), 0.666327(all), eps: 0.200000
DR: 0.746534(50), 0.666390(all), eps: 0.200000
DR: 0.675655(50), 0.666423(all), eps: 0.200000
DR: 0.771025(50), 0.666505(all), eps: 0.200000
DR: 0.705573(50), 0.666528(all), eps: 0.200000
DR: 0.694894(50), 0.666550(all), eps: 0.200000
DR: 0.744019(50), 0.666637(all), eps: 0.200000
DR: 0.741862(50), 0.666695(all), eps: 0.200000
DR: 0.661016(50), 0.666717(all), eps: 0.200000
DR: 0.684165(50), 0.666730(all), eps: 0.200000
DR: 0.815970(50), 0.666872(all), eps: 0.200000
DR: 0.652353(50), 0.666861(all), eps: 0.200000
DR: 0.635820(50), 0.666856(all), eps: 0.200000
DR: 0.732900(50), 0.666908(all), eps: 0.200000
DR: 0.711693(50), 0.666952(all), eps: 0.200000
DR: 0.785021(50), 0.667044(all), eps: 0.200000
DR: 0.670441(50), 0.667058(all), eps: 0.200000
DR: 0.623087(50), 0.667025(all), eps: 0.200000
DR: 0.736719(50), 0.667059(all), eps: 0.200000
DR: 0.677413(50), 0.667067(all), eps: 0.200000
DR: 0.685626(

DR: 0.742320(50), 0.672028(all), eps: 0.200000
DR: 0.714659(50), 0.672057(all), eps: 0.200000
DR: 0.699243(50), 0.672093(all), eps: 0.200000
DR: 0.801863(50), 0.672182(all), eps: 0.200000
DR: 0.666520(50), 0.672200(all), eps: 0.200000
DR: 0.716239(50), 0.672230(all), eps: 0.200000
DR: 0.746674(50), 0.672291(all), eps: 0.200000
DR: 0.628598(50), 0.672261(all), eps: 0.200000
DR: 0.712289(50), 0.672311(all), eps: 0.200000
DR: 0.704588(50), 0.672333(all), eps: 0.200000
DR: 0.637960(50), 0.672315(all), eps: 0.200000
DR: 0.772447(50), 0.672384(all), eps: 0.200000
DR: 0.606030(50), 0.672360(all), eps: 0.200000
DR: 0.724831(50), 0.672395(all), eps: 0.200000
DR: 0.704618(50), 0.672435(all), eps: 0.200000
DR: 0.614401(50), 0.672396(all), eps: 0.200000
DR: 0.693000(50), 0.672415(all), eps: 0.200000
DR: 0.716647(50), 0.672445(all), eps: 0.200000
DR: 0.748496(50), 0.672516(all), eps: 0.200000
DR: 0.680549(50), 0.672522(all), eps: 0.200000
DR: 0.697215(50), 0.672556(all), eps: 0.200000
DR: 0.686941(

DR: 0.720279(50), 0.675241(all), eps: 0.200000
DR: 0.712320(50), 0.675277(all), eps: 0.200000
DR: 0.688746(50), 0.675285(all), eps: 0.200000
DR: 0.684875(50), 0.675292(all), eps: 0.200000
DR: 0.688963(50), 0.675301(all), eps: 0.200000
DR: 0.671673(50), 0.675314(all), eps: 0.200000
DR: 0.600421(50), 0.675268(all), eps: 0.200000
DR: 0.693551(50), 0.675280(all), eps: 0.200000
DR: 0.639179(50), 0.675258(all), eps: 0.200000
DR: 0.709548(50), 0.675291(all), eps: 0.200000
DR: 0.726361(50), 0.675322(all), eps: 0.200000
DR: 0.669279(50), 0.675332(all), eps: 0.200000
DR: 0.659442(50), 0.675322(all), eps: 0.200000
DR: 0.685827(50), 0.675332(all), eps: 0.200000
DR: 0.649037(50), 0.675316(all), eps: 0.200000
DR: 0.613730(50), 0.675289(all), eps: 0.200000
DR: 0.633742(50), 0.675263(all), eps: 0.200000
DR: 0.621419(50), 0.675234(all), eps: 0.200000
DR: 0.696017(50), 0.675247(all), eps: 0.200000
DR: 0.569870(50), 0.675191(all), eps: 0.200000
DR: 0.696999(50), 0.675204(all), eps: 0.200000
DR: 0.681410(

DR: 0.662971(50), 0.679191(all), eps: 0.200000
DR: 0.633056(50), 0.679166(all), eps: 0.200000
DR: 0.676464(50), 0.679178(all), eps: 0.200000
DR: 0.656278(50), 0.679165(all), eps: 0.200000
DR: 0.740876(50), 0.679212(all), eps: 0.200000
DR: 0.778431(50), 0.679266(all), eps: 0.200000
DR: 0.677891(50), 0.679259(all), eps: 0.200000
DR: 0.758248(50), 0.679302(all), eps: 0.200000
DR: 0.826597(50), 0.679395(all), eps: 0.200000
DR: 0.702893(50), 0.679408(all), eps: 0.200000
DR: 0.727669(50), 0.679448(all), eps: 0.200000
DR: 0.600428(50), 0.679405(all), eps: 0.200000
DR: 0.646971(50), 0.679403(all), eps: 0.200000
DR: 0.629339(50), 0.679376(all), eps: 0.200000
DR: 0.752398(50), 0.679423(all), eps: 0.200000
DR: 0.728059(50), 0.679449(all), eps: 0.200000
DR: 0.621425(50), 0.679433(all), eps: 0.200000
DR: 0.655545(50), 0.679421(all), eps: 0.200000
DR: 0.649696(50), 0.679399(all), eps: 0.200000
DR: 0.599152(50), 0.679356(all), eps: 0.200000
DR: 0.751850(50), 0.679409(all), eps: 0.200000
DR: 0.721390(

In [None]:
offline_model.save_parameters(temporary_model+".best")