-
Notifications
You must be signed in to change notification settings - Fork 13
/
agentServer.py
115 lines (95 loc) · 2.92 KB
/
agentServer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import cherrypy
import argparse
from ws4py.server.cherrypyserver import WebSocketPlugin, WebSocketTool
from ws4py.websocket import WebSocket
import msgpack
import io
from PIL import Image
import threading
from .tensor import *
import numpy as np
from chainer import Variable, FunctionSet, optimizers
from PIL import ImageOps
class Root(object):
    """CherryPy application root served by StartAgent()."""

    @cherrypy.expose
    def index(self):
        """Return a placeholder page for '/'."""
        return 'some HTML with a websocket javascript connection'

    @cherrypy.expose
    def ws(self):
        """Route for the WebSocket endpoint.

        StartAgent() enables ws4py's WebSocketTool on '/ws'. That tool handles
        the upgrade and exposes the per-connection handler instance as
        ``cherrypy.request.ws_handler``. The body only has to exist so that
        CherryPy routes the request.
        """
        connection = cherrypy.request.ws_handler  # handler instance for this socket
# Frame callback installed by StartAgent(); AgentServer calls it once per
# incoming message (more than once at an episode boundary).
workout = None
# Most recent grayscale depth frame. AgentServer.received_message writes it
# and DepthImage() reads it.
depth_image = None
# Number of values in a flattened depth frame (presumably a 32x32 image).
Depth_dim = 32 * 32
def DepthImage():
    """Wrap the latest depth frame as a flat Tensor of ``Depth_dim`` values.

    Reads the module-level ``depth_image`` that AgentServer.received_message
    publishes for every incoming message.

    Raises:
        ValueError: if no depth frame has been received yet, or (from numpy's
            reshape) if the frame does not hold exactly ``Depth_dim`` pixels.
    """
    if depth_image is None:
        # Without this guard numpy reports the confusing
        # "cannot reshape array of size 1" for the None placeholder.
        raise ValueError(
            'DepthImage() called before any depth frame was received')
    return Tensor(value=np.asarray(depth_image).reshape(Depth_dim))
def Concat(y, x=None):
    """Join two tensors' values along the first axis into a new ChainerTensor.

    Args:
        y: tensor whose ``.value`` is placed second.
        x: tensor whose ``.value`` is placed first; falls back to
            ``Tensor.context`` when omitted.

    Returns:
        A ChainerTensor wrapping a chainer Variable built from the joined
        array. The tensor's ``use()`` has already been called.
    """
    head = Tensor.context if x is None else x
    joined = np.r_[head.value, y.value]
    # volatile=True is the chainer v1 flag for variables that do not record
    # a computational graph.
    result = ChainerTensor(Variable(joined, volatile=True))
    result.use()
    return result
class AgentServer(WebSocket):
    """ws4py handler that bridges simulator frames and the training callback.

    Every incoming message is a msgpack map with the keys 'image', 'depth',
    'reward' and 'endEpisode'. The decoded RGB frame is passed to the
    module-level ``workout`` callable (installed by StartAgent). The resulting
    action is sent back to the client as a string.

    Before each call, ``AgentServer.mode`` is set to 'start', 'step' or 'end'.
    This is presumably so the trainer can tell which phase it is in (TODO
    confirm against the trainer implementation).
    """

    agent_initialized = False
    cycle_counter = 0
    # One Event shared by all connections (class attribute). It is set after
    # every handled message and never cleared here, so wait() only blocks
    # until the first message has been processed.
    thread_event = threading.Event()
    trainer = None
    mode = 'none'
    # NOTE(review): never assigned in this file. The per-message reward only
    # lives in a local of received_message; confirm whether anything reads this.
    reward = None
    log_file = 'log_reward.log'
    reward_sum = 0

    def received_message(self, m):
        """Handle one frame: publish the depth image, query the trainer, reply."""
        global depth_image
        packet = msgpack.unpackb(m.data)
        frame = Image.open(io.BytesIO(bytearray(packet['image'])))
        reward = packet['reward']
        end_episode = packet['endEpisode']
        # Published for DepthImage(), which reads this module global.
        raw_depth = Image.open(io.BytesIO(bytearray(packet['depth'])))
        depth_image = ImageOps.grayscale(raw_depth)

        if self.agent_initialized:
            self.thread_event.wait()
            # Augmented assignment on self creates per-instance copies of
            # these class-level defaults.
            self.cycle_counter += 1
            self.reward_sum += reward
            if not end_episode:
                AgentServer.mode = 'step'
                agent, action, eps, q_now, obs = workout(frame)
                self.send(str(action))
                agent.step_after(reward, action, eps, q_now, obs)
            else:
                # Call the trainer in 'end' mode, then again in 'start' mode
                # to get the first action of the next episode.
                AgentServer.mode = 'end'
                workout(frame)
                AgentServer.mode = 'start'
                self.send(str(workout(frame)))
                with open(self.log_file, 'a') as log:
                    log.write('%s,%s\n' % (self.cycle_counter, self.reward_sum))
                self.reward_sum = 0
        else:
            # First message handled by this instance: start the agent and
            # truncate the reward log, writing its CSV header.
            self.agent_initialized = True
            AgentServer.mode = 'start'
            self.send(str(workout(frame)))
            with open(self.log_file, 'w') as log:
                log.write('cycle, episode_reward_sum \n')

        self.thread_event.set()
def StartAgent(trainer=None, port=8765):
    """Install *trainer* as the frame callback and serve the agent via CherryPy.

    Args:
        trainer: callable stored in the module-level ``workout`` and invoked
            by AgentServer for each incoming frame.
        port: TCP port for CherryPy to listen on.

    Blocks inside ``cherrypy.quickstart`` until the server stops.
    """
    global workout
    workout = trainer

    cherrypy.config.update({'server.socket_port': port})
    # ws4py's CherryPy integration: subscribe its engine plugin and register
    # the tool that the '/ws' route enables below.
    WebSocketPlugin(cherrypy.engine).subscribe()
    cherrypy.tools.websocket = WebSocketTool()
    # Turn off CherryPy's autoreloader.
    cherrypy.config.update({'engine.autoreload.on': False})

    ws_route = {
        'tools.websocket.on': True,
        'tools.websocket.handler_cls': AgentServer,
    }
    cherrypy.quickstart(Root(), '/', {'/ws': ws_route})