Skip to content

Commit e9550bd

Browse files
misc refactoring
1 parent 8209e77 commit e9550bd

File tree

3 files changed

+25
-23
lines changed

3 files changed

+25
-23
lines changed

Agent/DQNEnsembleAgent.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,35 +3,23 @@
33
import tensorflow.keras as keras
44
import tensorflow.keras.layers as layers
55
import tensorflow as tf
6+
from Agent.MaskedSoftmax import MaskedSoftmax
67

78
def combineModels(models, combiner):
89
shape = models[0].layers[0].input_shape[0][1:]
910
inputs = layers.Input(shape=shape)
1011
actionsMask = layers.Input(shape=(4, ))
11-
res = layers.Lambda(combiner)([actionsMask] + [ x(inputs) for x in models ])
12-
return keras.Model(inputs=[inputs, actionsMask], outputs=res)
13-
14-
def maskedSoftmax(mask, inputs):
15-
mask = tf.where(tf.equal(mask, 1))
16-
return [
17-
tf.sparse.to_dense(
18-
tf.sparse.softmax(
19-
tf.sparse.SparseTensor(
20-
indices=mask,
21-
values=tf.gather_nd(x, mask),
22-
dense_shape=tf.shape(x, out_type=tf.int64)
23-
)
24-
)
25-
) for x in inputs
26-
]
2712

28-
def multiplyOutputs(inputs):
29-
outputs = maskedSoftmax(inputs[0], inputs[1:])
13+
predictions = [ layers.Reshape((1, -1))(
14+
MaskedSoftmax()( x(inputs), actionsMask )
15+
) for x in models ]
3016

31-
res = 1 + outputs[0]
32-
for x in outputs[1:]:
33-
res = tf.math.multiply(res, 1 + x)
34-
return res
17+
res = layers.Lambda(combiner)( layers.Concatenate(axis=1)(predictions) )
18+
return keras.Model(inputs=[inputs, actionsMask], outputs=res)
19+
20+
@tf.function
def multiplyOutputs(outputs):
    """Combine stacked predictions by multiplying them together.

    Every value is shifted by one before the product is taken, and the
    product runs along axis 1, so that axis is removed from the result.

    Args:
        outputs: tensor of predictions. Presumably shaped
            (batch, models, actions), matching the axis-1 concatenation
            of reshaped per-model outputs in ``combineModels`` -- confirm
            against the caller.

    Returns:
        Tensor holding the product of ``1 + outputs`` over axis 1.
    """
    shifted = 1 + outputs
    return tf.math.reduce_prod(shifted, axis=1)
3523

3624
ENSEMBLE_MODE = {
3725
'multiply': multiplyOutputs

Agent/MaskedSoftmax.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import tensorflow as tf
2+
3+
class MaskedSoftmax(tf.keras.layers.Layer):
    """Keras layer applying a softmax only over the entries a mask selects."""

    def call(self, inputLayer, mask):
        """Return a dense softmax of ``inputLayer`` limited to masked-in entries.

        Args:
            inputLayer: tensor of scores to normalise.
            mask: tensor whose entries equal to 1 mark the positions to keep.
                Assumed to have the same shape as ``inputLayer`` -- confirm
                against the caller.

        Returns:
            Dense tensor with the shape of ``inputLayer``. Positions that
            were not selected by ``mask`` are left at the default fill value
            of ``tf.sparse.to_dense``.
        """
        # Coordinates of every position where the mask is exactly 1.
        keptIndices = tf.where(tf.equal(mask, 1))
        keptValues = tf.gather_nd(inputLayer, keptIndices)
        denseShape = tf.shape(inputLayer, out_type=tf.int64)

        # Going through a sparse tensor makes the softmax see only the kept
        # entries; everything else is absent rather than a very small number.
        sparseScores = tf.sparse.SparseTensor(
            indices=keptIndices,
            values=keptValues,
            dense_shape=denseShape,
        )
        return tf.sparse.to_dense(tf.sparse.softmax(sparseScores))

view_maze.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# -*- coding: utf-8 -*-
33
import tensorflow as tf
44
import os
5-
from Agent.DQNEnsembleAgent import DQNEnsembleAgent
65
# limit GPU usage
76
gpus = tf.config.experimental.list_physical_devices('GPU')
87
tf.config.experimental.set_virtual_device_configuration(
@@ -15,6 +14,7 @@
1514
import pygame.locals as G
1615
import random
1716
from Agent.DQNAgent import DQNAgent
17+
from Agent.DQNEnsembleAgent import DQNEnsembleAgent
1818
import glob
1919
from collections import namedtuple
2020
from model import createModel

0 commit comments

Comments
 (0)