-
Notifications
You must be signed in to change notification settings - Fork 1
/
drive.py
75 lines (54 loc) · 2.49 KB
/
drive.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import random
from collections import deque
import numpy as np
import cv2
import time
import tensorflow as tf
import keras.backend.tensorflow_backend as backend
from keras.models import load_model

from model import CarEnv, MEMORY_FRACTION

# Trained model checkpoint to run inference with.
MODEL_PATH = 'models/Xception__-118.00max_-179.10avg_-250.00min__1566603992.model'


if __name__ == '__main__':
    # Memory fraction: cap how much GPU memory TensorFlow may grab so the
    # CARLA simulator running alongside keeps enough VRAM for itself.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=MEMORY_FRACTION)
    backend.set_session(tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)))

    # Load the trained model.
    model = load_model(MODEL_PATH)

    # Create environment.
    env = CarEnv()

    # For agent speed measurements - keeps last 60 frametimes.
    fps_counter = deque(maxlen=60)

    # Warm-up prediction: the first call is slow because of lazy graph/session
    # initialization, so do it once here rather than inside a timed episode step.
    model.predict(np.ones((1, env.im_height, env.im_width, 3)))

    # Loop over episodes.
    while True:
        print('Restarting episode')

        # Reset environment and get initial state.
        current_state = env.reset()
        env.collision_hist = []

        done = False

        # Loop over steps.
        while True:
            # For FPS counter.
            step_start = time.time()

            # Show current frame.
            # NOTE(review): was f'Agent - preview' — an f-string with no
            # placeholders (lint F541); plain literal is byte-identical.
            cv2.imshow('Agent - preview', current_state)
            cv2.waitKey(1)

            # Predict an action from the current observation; pixels are scaled
            # to [0, 1] to match the preprocessing used during training.
            qs = model.predict(np.array(current_state).reshape(-1, *current_state.shape) / 255)[0]
            action = np.argmax(qs)

            # Step environment (additional flag informs environment to not
            # break an episode by time limit).
            new_state, reward, done, _ = env.step(action)

            # Set current step for next loop iteration.
            current_state = new_state

            # If done - agent crashed, break an episode.
            if done:
                break

            # Measure step time, append to a deque, then print mean FPS for
            # last 60 frames, q values and taken action.
            frame_time = time.time() - step_start
            fps_counter.append(frame_time)
            print(f'Agent: {len(fps_counter)/sum(fps_counter):>4.1f} FPS | Action: [{qs[0]:>5.2f}, {qs[1]:>5.2f}, {qs[2]:>5.2f}] {action}')

        # Destroy the episode's actors so CARLA does not leak vehicles/sensors
        # across episodes.
        for actor in env.actor_list:
            actor.destroy()