In [82]:
import tensorflow as tf
import numpy as np

In [99]:
class R2D2(tf.keras.Model):
    def __init__(self, num_of_actions):
        super(R2D2, self).__init__()
        self.num_of_actions = num_of_actions

        self.cnn1 = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Conv2D(
                32, kernel_size=(8, 8), strides=(4, 4),
                activation=tf.keras.activations.relu, padding='same'))
        self.cnn2 = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Conv2D(
                64, kernel_size=(4, 4), strides=(2, 2),
                activation=tf.keras.activations.relu, padding='same'))
        self.cnn3 = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Conv2D(
                64, kernel_size=(3, 3), strides=(1, 1),
                activation=tf.keras.activations.relu, padding='same'))
        self.flatten = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Flatten())
        self.dense = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Dense(
                512, activation=tf.keras.activations.relu))
        
        self.lstm = tf.keras.layers.LSTM(512)

        self.dense11 = tf.keras.layers.Dense(
            512, activation=tf.keras.activations.relu)
        self.dense12 = tf.keras.layers.Dense(1)
        
        self.dense21 = tf.keras.layers.Dense(
            512, activation=tf.keras.activations.relu)
        self.dense22 = tf.keras.layers.Dense(18)
    
    def call(self, x, a, re, ri, beta):
        input = x
        x = self.cnn1(input)
        x = self.cnn2(x)
        x = self.cnn3(x)
        x = self.flatten(x)
        x = self.dense(x)

        a = tf.one_hot(a, self.num_of_actions)
        print(x.shape, a.shape, re.shape, ri.shape, beta.shape)
        concat = tf.concat([x, a, re, ri, beta], axis=2)
        
        x = self.lstm(concat)

        x1 = self.dense11(x)
        x1 = self.dense12(x1)

        x2 = self.dense21(x)
        x2 = self.dense22(x2)

        output = x1 + x2 - tf.math.reduce_mean(x2, axis=1)
        return output

In [100]:
num_of_actions = 4
batch_size = 1
time_sequence = 32
x = tf.random.normal((batch_size, time_sequence, 50, 50, 1))
a = tf.convert_to_tensor(np.random.choice(num_of_actions, size=(batch_size, time_sequence)))
re = tf.random.normal((batch_size, time_sequence, 1))
ri = tf.random.normal((batch_size, time_sequence, 1))
beta = tf.random.normal((batch_size, time_sequence, 1))

In [101]:
agent = R2D2(num_of_actions)

In [102]:
agent(x, a, re, ri, beta)

(1, 32, 512) (1, 32, 4) (1, 32, 1) (1, 32, 1) (1, 32, 1)


<tf.Tensor: shape=(1, 18), dtype=float32, numpy=
array([[ 0.02619118,  0.0324901 , -0.10092699,  0.05801334, -0.01880517,
        -0.01457014,  0.01139851, -0.04672857, -0.05129401, -0.02857853,
        -0.09235153, -0.11016186, -0.0274087 , -0.02625233, -0.10065489,
         0.00820833, -0.00032712, -0.0110217 ]], dtype=float32)>

In [4]:
# from moving_average import MovingAverage
from episodic_novelty import EpisodicNovelty
from life_long_novelty import LifeLongNovelty

In [5]:
num_of_actions = 4

episodic_novelty_module = EpisodicNovelty(num_of_actions)
life_long_novelty_module = LifeLongNovelty()

In [8]:
import tensorflow as tf
num_of_actions = 4
batch_size = 1
time_sequence = 32
obs = tf.random.normal((batch_size, time_sequence, 50, 50, 1))
next_obs = tf.random.normal((batch_size, time_sequence, 50, 50, 1))

L = 5
episodic_reward = episodic_novelty_module(obs, next_obs)
modulator = life_long_novelty_module(obs)
intrinsic_reward = episodic_reward * tf.minimum(tf.maximum(modulator, 1), L)

In [9]:
intrinsic_reward

<tf.Tensor: shape=(1, 32, 128), dtype=float32, numpy=
array([[[0.06521409, 0.03260705, 0.03260705, ..., 0.03260705,
         0.03260705, 0.03260705],
        [0.03260705, 0.06521409, 0.06521409, ..., 0.03260705,
         0.06521408, 0.06521409],
        [0.06521409, 0.03260705, 0.06521409, ..., 0.06521409,
         0.03260705, 0.06521409],
        ...,
        [0.06521409, 0.06521409, 0.06521409, ..., 0.03260705,
         0.03260705, 0.06521409],
        [0.03260705, 0.03260705, 0.0652141 , ..., 0.06521409,
         0.03260705, 0.03260705],
        [0.03260705, 0.06521409, 0.03260705, ..., 0.06521409,
         0.03260705, 0.03260705]]], dtype=float32)>