In [3]:
import tensorflow as tf
import numpy as np

In [121]:
ENV_SIZE = 16
INITIAL_POSITION = 1
FOOD_POSITION = ENV_SIZE/2
LEARNING_RATE = 0.01


tf.reset_default_graph()
Psi = tf.Variable(initial_value=INITIAL_POSITION, trainable=False, name='external')  # external states
S = tf.Variable(initial_value=1.0, trainable=False, name='sensory')  # sensory states
A = tf.Variable(initial_value=1.0, trainable=True, name='action')  # action states
R = tf.Variable(initial_value=[1/16]*16, trainable=True, name='internal')  # internal states


def environmental_dynamics(psi, a):
    # How position changes given action state
    return (psi + a) % ENV_SIZE


def sensory_dynamics(psi):
    # P(s | psi)
    # Sensory states given external states
    k = 4**(-1/16)
    omega = np.log(4)/16
    return k * tf.exp(-omega * tf.abs(psi - FOOD_POSITION))


def model_encoding(R):
    # P(psi | r)
    # External state given internal state
    return tf.nn.softmax(R)


def generative_density(R, a):
    return environmental_dynamics(sensory_dynamics(model_encoding(R)), a)


def variational_density(R):
    """
    Agent belief about the external states (i.e. its current position in the 
    world) as encoded in internal states
    """
    return tf.nn.softmax(R)


def KL(a, b):
    """
    Kullback-Leibler divergence between densities a and b
    """
    return tf.reduce_mean(
        -tf.nn.softmax_cross_entropy_with_logits(
            logits=a, 
            labels=b
        )
    )


def free_energy(R, s, a):
    divergence = KL(variational_density(R), generative_density(R, a))
    surprisal = tf.log(tf.multiply(s, a))
    return divergence - surprisal

cost = free_energy(R, S, A)

In [122]:
# Initializing the variables
init = tf.global_variables_initializer()

# Initialize optimizer
optimizer = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(cost)

# Launch the graph
with tf.Session() as sess:
    sess.run(init)
    for var in tf.global_variables():
        print('{name} = {value}'.format(name=var.name, value=var.eval(session=sess)))
    
    print(generative_density(R, A).eval())
    print(free_energy(R, S, A).eval())
    
    summary_writer = tf.summary.FileWriter('logs', graph=tf.get_default_graph())
    summary_writer.add_graph(tf.get_default_graph())

    for epoch in range(10):
        sess.run(optimizer)
        for var in tf.global_variables():
            print('{name} = {value}'.format(name=var.name, value=var.eval(session=sess)))

external:0 = 1
sensory:0 = 1.0
action:0 = 1.0
internal:0 = [ 0.0625  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625
  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625]
[ 1.46099162  1.46099162  1.46099162  1.46099162  1.46099162  1.46099162
  1.46099162  1.46099162  1.46099162  1.46099162  1.46099162  1.46099162
  1.46099162  1.46099162  1.46099162  1.46099162]
-64.8117


external:0 = 1
sensory:0 = 1.0
action:0 = 1.0099999904632568
internal:0 = [ 0.0625  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625
  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625]
external:0 = 1
sensory:0 = 1.0
action:0 = 1.0199010372161865
internal:0 = [ 0.0625  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625
  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625]
external:0 = 1
sensory:0 = 1.0
action:0 = 1.0297058820724487
internal:0 = [ 0.0625  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625
  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625]
external:0 = 1
sensory:0 = 1.0
action:0 = 1.0394173860549927
internal:0 = [ 0.0625  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625
  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625]
external:0 = 1
sensory:0 = 1.0
action:0 = 1.0490381717681885
internal:0 = [ 0.0625  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625  0.0625
  0.0625  0.0625  0.0625  0.0625 