In [1]:
import nengo
import matplotlib.pyplot as plt
import numpy
import gym
import math
%matplotlib inline

# Network details

## Goal Directed/Model-Based

1. LatPFC Maintains a repository representing the environment. Repository of all states and actions.
   Receives cues from the environment and maintains it
   Receives SPE from preSMA and/or pIPS
   
2. LatOFC receives the repository and State Prediction Error from LatPFC and maintains a map of the environment. Something like an internal environment. Updates trainsition probabilities by SPE and updates reward values from dopamine errors.

3. mOFC uses the internal environment from map from LatOFC to do state space search. Receives dopamine errors and passes it to LatOFC to update the internal environment. Passes expected value of the next state after DFS to vmPFC.

4. vmPFC maintains the expected value and passes it to the striatum. Updates expected value for that state from dopamne errors from the striatum and passes it to mOFC.

5. DMS represents expected value. Modulated by Dopamine. Passes it back to vmPFC

6. preSMA/pIPS compare expected state and actual and generates SPE. Passes SPE to latPFC.

## Habitual/Model-Free

1. latPFC maintains the repository of states and passes cue and the repository to dlPFC.

2. dlPFC maintains q_values. Constructs them using the repository from latPFC and dopamine errors from DLS. Uses the cue to decide which value to pass to the DLS.

3. DLS represents expected q_value. Modulated by Dopamine. Passes it back to dlPFC

In [2]:
env = gym.make("FrozenLake-v0")
state = env.reset()

[2017-09-05 17:46:17,608] Making new env: FrozenLake-v0


In [3]:
def to_discrete(vector):
    discrete = vector[0]*4 + vector[1]
    return discrete

def to_vector(discrete):
    vector = numpy.zeros((1, 2))
    vector[1] = discrete%4
    vector[0] = int(discrete/4)
    return vector


In [4]:
class latPFC:
    
    def __init__(self):
        self.state_repo = list()
        self.action_repo = list()
        
    def build_state_repo(self, env):
        for i in range(env.observation_space.n/4):
            for j in range(env.observation_space.n/4):
                self.state_repo.append([i, j])
        self.state_repo = numpy.array(self.state_repo)

    def bulid_action_repo(self, env):
        for i in range(env.action_space.n/4):
            for j in range(env.action_space.n/4):
                self.action_repo.append([i, j])
        self.action_repo = numpy.array(self.action_repo)
    
    def get_cue(self, cue):
        self.state = env.observation_space.sample()
    
    def calc_spe(self, exp_state, actual_state):
        if exp_state == actual_state:
            self.spe = 0
        else:
            self.spe = 1
        return spe
    

In [5]:
#spe is the state prediction error. 1 or 0. 1 is no error
#states is matrix of value for each state
#actions is actions in order "left, down, right, up"

class latOFC:
    def __init__(self, state_repo, action_repo, spe):
        self.spe = spe
        self.spe_sum = 0
        self.count = 0
        self.states = numpy.zeros((len(state_repo)/4, len(state_repo)/4))
        self.states[0][1] = 1
        self.states[0][2] = 1
        self.actions = numpy.array([[0, -1], [1, 0], [0, 1], [-1, 0]])
        self.transition_prob = 1.0

        
    def update_transition_prob(self, spe):
        self.count += 1
        self.spe_sum += spe
        self.transition_prob = float(self.spe_sum)/self.count
    
    @classmethod
    def transition(self, state, action, transition_prob):
        if numpy.random.rand() < transition_prob:
            temp = numpy.add(state, action)
            if temp[0] > 3 or temp[1] > 3 or temp[0] < 0 or temp[1] < 0:
                new_state = state
            else:
                new_state = temp
            
        else:
            rand_action = self.actions[env.action_space.sample()]
            temp = self.states[numpy.add(state, rand_action)]
            if temp[0] > 3 or temp[1] > 3 or temp[1] < 0 or temp[1] < 0:
                new_state = state
            else:
                new_state = temp
        return new_state                            

       

In [59]:
class mOFC:
    #transition_prob, states and actions received from latPFC
    def __init__(self, transition_prob, states, actions):
        self.transition_prob = transition_prob
        self.states = states
        print states
        self.actions = actions

    
    #calculate value of current state using depth limited search
    @classmethod
    def dls(self, curr_state, curr_depth, max_depth):
        value_list = list()
        if curr_depth == max_depth:
            return self.states[curr_state[0]][curr_state[1]]
        else:
            for i in self.actions:
                next_state = latOFC.transition(curr_state, i, self.transition_prob)
                #if list(next_state) in visited:
                 #   continue
                value = self.states[next_state[0]][next_state[1]] + 0.9*self.dls(next_state, curr_depth+1, max_depth)
                value_list.append(value)
            if value_list:
                average_value = numpy.mean(numpy.array(value_list))
        return average_value
    
    def value_of_action(self, curr_state):
        value_list = list()
        for i in self.actions:
            next_state = latOFC.transition(curr_state, i, self.transition_prob)
            value_list.append(self.dls(next_state, 0, 3, self.visited))
        return value_list
        

In [80]:
class vmPFC(mOFC):
    def __init__(self, states, actions, curr_state, transition_prob):
        mOFC.__init__(self, transition_prob, states, actions)
        self.curr_state = curr_state
    def value_of_actions(self):
        value_list = list()
        for i in self.actions:
            next_state = latOFC.transition(self.curr_state, i, self.transition_prob)
            value_list.append(self.dls(next_state, 0, 3))

In [70]:
a = latPFC()

In [71]:
a.build_state_repo(env)
a.bulid_action_repo(env)

In [72]:
b = latOFC(a.state_repo, a.action_repo, 0)

In [77]:
c = mOFC(1.0, b.states, b.actions)

[[ 0.  1.  1.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]]


In [81]:
d = vmPFC(c.states, c.actions, numpy.array([1, 0]), 1.0)

[[ 0.  1.  1.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]]


In [82]:
print d.value_of_actions()

AttributeError: class vmPFC has no attribute 'actions'

In [None]:
print b.transition(numpy.array([1, 1]), numpy.array([0, -1]), 1.0)