# Introduction to successor representations

BLG JC 06.06.2025

Daniel Kornai & Puria Radmard

## Notation

Assuming a basic awareness of MDPs and RL

$i,j\in\mathcal{N}$ - non-terminal states

$k\in\mathcal{T}$ - terminal states

$Q_{ij}$ - probability of moving from state $i$ to non-terminal state $j$

$s_{ik}$ - probability of moving from state $i$ to terminal state $k$

$z_k$ - reward from terminating at state $k$, with mean $\bar{z}_k$

## Expected returns

Vector of *immediate* mean return: $\boldsymbol{h}$, with $h_i = \sum_{k\in\mathcal{T}} s_{ik} \bar{z}_k$, i.e. the expected reward from moving from $i$ directly to a terminal state

Vector of *overall* mean return with no discount factor: $\bar{\boldsymbol{r}} = \boldsymbol{h} + Q\boldsymbol{h} + Q^2\boldsymbol{h} + \dots = [I - Q]^{-1}\boldsymbol{h}$ (the series converges as long as termination is reached with probability 1 from every non-terminal state, so that $Q^n \to 0$)

Vector of *overall* mean return with discount factor $\gamma$: $\bar{\boldsymbol{r}} = \boldsymbol{h} + \gamma Q\boldsymbol{h} + \gamma^2 Q^2\boldsymbol{h} + \dots = [I - \gamma Q]^{-1}\boldsymbol{h}$

Equivalently, in recursive (Bellman) form: $\bar{\boldsymbol{r}} = \boldsymbol{h} + \gamma Q \bar{\boldsymbol{r}}$
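As a quick sanity check, the sketch below builds a small hypothetical chain (the matrices `Q`, `s` and the rewards `z_bar` are made up for illustration) and confirms numerically that the truncated series, the matrix inverse and the recursive form give the same $\bar{\boldsymbol{r}}$.

```python
import numpy as np

gamma = 0.9

# Hypothetical chain: 3 non-terminal states, 2 terminal states.
# Each row of [Q | s] is a distribution over next states and must sum to 1.
Q = np.array([
    [0.0, 0.5, 0.0],
    [0.2, 0.0, 0.5],
    [0.0, 0.3, 0.0],
])
s = np.array([
    [0.5, 0.0],
    [0.0, 0.3],
    [0.0, 0.7],
])
z_bar = np.array([-1.0, 2.0])  # mean reward at each terminal state
assert np.allclose(Q.sum(axis=1) + s.sum(axis=1), 1.0)

h = s @ z_bar  # immediate mean return, h_i = sum_k s_ik * z_bar_k

# Closed form [I - gamma Q]^{-1} h (solving is preferred to forming the inverse)
r_inv = np.linalg.solve(np.eye(3) - gamma * Q, h)

# Truncated series h + gamma Q h + gamma^2 Q^2 h + ...
r_series = np.zeros(3)
term = h.copy()
for _ in range(200):
    r_series += term
    term = gamma * Q @ term

print(r_inv, r_series)
```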

## Value approximation

We represent each state $i$ with vector $\boldsymbol{x}_i$

Want to learn an approximation $\hat{r}(i;\boldsymbol{w}) = \boldsymbol{x}_i^\intercal \boldsymbol{w} \approx \bar{r}_i$

i.e. $\hat{\boldsymbol{r}}(\boldsymbol{w}) = X \boldsymbol{w}$, where $X$ is the matrix whose $i$-th row is $\boldsymbol{x}_i^\intercal$

## Temporal difference learning of value approximation

In this simple setup, reward only comes at terminal states (there is no reward for steps between non-terminal states), so we would like the following consistency to hold:

$\hat{r}(i;\boldsymbol{w}) \approx \sum_{k} s_{ik} \bar{z}_k + \gamma \sum_{j} Q_{ij} \hat{r}(j;\boldsymbol{w})$

Assuming we have a good state representation set $X$, we want to learn $\boldsymbol{w}$ from experience. 

Say we take a step from state $i$ to state $j$, both non-terminal. We can move towards this consistency by minimising, on each step, the squared temporal-difference (TD) error:

$\epsilon = | \hat{r}(i;\boldsymbol{w}) - \gamma\hat{r}(j;\boldsymbol{w}) |^2 = (\boldsymbol{x}_i^\intercal\boldsymbol{w} - \gamma\hat{r}(j;\boldsymbol{w}))^2$

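e.g. by stochastic semi-gradient descent, which treats the bootstrapped target $\gamma\hat{r}(j;\boldsymbol{w})$ as a constant and differentiates only through $\hat{r}(i;\boldsymbol{w})$: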

$\nabla_{\boldsymbol{w}}\epsilon = 2 (\boldsymbol{x}_i^\intercal\boldsymbol{w} - \gamma\hat{r}(j;\boldsymbol{w})) \boldsymbol{x}_i$

with some learning rate $\alpha$ (absorbing the factor of 2):

$\boldsymbol{w} \longleftarrow \boldsymbol{w} - \alpha (\boldsymbol{x}_i^\intercal\boldsymbol{w} - \gamma\boldsymbol{x}_j^\intercal\boldsymbol{w}) \boldsymbol{x}_i$

If instead we step from $i$ to a terminal state and receive reward $z$, the episode ends and there is no bootstrapped term:

$\epsilon = (\boldsymbol{x}_i^\intercal\boldsymbol{w} - z)^2 \therefore \boldsymbol{w} \longleftarrow \boldsymbol{w} - \alpha (\boldsymbol{x}_i^\intercal\boldsymbol{w} - z) \boldsymbol{x}_i$
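The two update rules can be wrapped in one function. The sketch below uses illustrative assumptions: a hypothetical helper `td_update`, a hypothetical deterministic chain 0 → 1 → 2 → terminal with terminal reward 1, and one-hot features (so $X = I$ and linear TD reduces to tabular TD). The learned values should approach $\bar{\boldsymbol{r}} = (\gamma^2, \gamma, 1)$.

```python
import numpy as np

def td_update(w, x_i, alpha, gamma, x_j=None, z=None):
    """One semi-gradient TD(0) step on w for a transition out of state i.

    Non-terminal step (pass x_j): target is gamma * x_j^T w (no intermediate reward).
    Terminal step (pass z): target is the terminal reward z.
    The target is treated as a constant; only x_i^T w is differentiated.
    """
    target = gamma * (x_j @ w) if x_j is not None else z
    td_error = x_i @ w - target
    return w - alpha * td_error * x_i

gamma, alpha = 0.9, 0.5
X = np.eye(3)    # one-hot features: row i is x_i
w = np.zeros(3)

# Hypothetical deterministic chain 0 -> 1 -> 2 -> terminal (reward 1)
for _ in range(200):
    w = td_update(w, X[0], alpha, gamma, x_j=X[1])
    w = td_update(w, X[1], alpha, gamma, x_j=X[2])
    w = td_update(w, X[2], alpha, gamma, z=1.0)

print(X @ w)  # converges to (gamma**2, gamma, 1)
```

With non-one-hot features the same function applies unchanged; only the rows of `X` differ.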

## But what should the representation set $X$ look like?

We want the linear approximation $\hat{\boldsymbol{r}}$ to incorporate information about the successors of the current state. Therefore, "a good representation for a state would be one that resembles the representations of its successors"

For example, we can define $X_{ij} = [\boldsymbol{x}_{i}]_j$ to be the *discounted expected occupancy of state j after starting at state i*

That is: $X_{ij} = \sum_{\tau=0}^\infty \gamma^\tau p(s_{t + \tau} = j | s_t = i)$

So: $X = I + \gamma Q + \gamma^2 Q^2 + ... = [I - \gamma Q]^{-1}$

Then the optimal weight vector is immediately $\boldsymbol{w}^* = \boldsymbol{h}$, since $X\boldsymbol{w}^* = [I - \gamma Q]^{-1}\boldsymbol{h} = \bar{\boldsymbol{r}}$

Here, $\boldsymbol{x}_i$ is called the *successor representation* of state $i$.

Sneakily, we have glossed over the split between terminal and non-terminal states (and between $Q$ and $s$), but the key idea still stands
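A small numerical illustration, assuming a hypothetical symmetric random walk on 4 non-terminal states where stepping off either end terminates and only the right exit is rewarded. It checks that $X = [I - \gamma Q]^{-1}$ satisfies $X = I + \gamma Q X$ (each $\boldsymbol{x}_i$ is its own indicator plus the discounted, averaged representations of its successors) and that the readout $X\boldsymbol{h}$ satisfies the Bellman equation.

```python
import numpy as np

gamma = 0.9
n = 4

# Hypothetical symmetric random walk on states 0..3; stepping off either end terminates
Q = 0.5 * (np.eye(n, k=1) + np.eye(n, k=-1))
s = np.zeros((n, 2))
s[0, 0] = 0.5                 # left exit
s[n - 1, 1] = 0.5             # right exit
z_bar = np.array([0.0, 1.0])  # only the right exit is rewarded

X = np.linalg.inv(np.eye(n) - gamma * Q)  # successor representation; row i is x_i
h = s @ z_bar
w_star = h                                # optimal weights
r_hat = X @ w_star                        # linear readout of the values

# Each x_i is its own indicator plus the discounted average of its successors' x_j
print(np.allclose(X, np.eye(n) + gamma * Q @ X))  # True
```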

## TD learning of successor representation

We have a similar consistency which we can approximate using gradient updates:

$X \approx I + \gamma QX$

When we observe a step from $i$ to $j$, we can update:

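$\forall k: \quad X_{ik} \longleftarrow X_{ik} - \alpha [X_{ik} - (\delta_{ik} + \gamma X_{jk})]$

This update follows intuitively from the definition of $X_{ij}$: the observed successor $j$ is a one-sample estimate of the expectation $\sum_{j'} Q_{ij'} X_{j'k}$ in the consistency above, so for all $k$ we move $X_{ik}$ towards $\delta_{ik} + \gamma X_{jk}$. If the step from $i$ terminates, the target is just $\delta_{ik}$.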
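A minimal sketch of this update, with a hypothetical helper `sr_td_update` acting on a dense matrix `X` (one row per non-terminal state), run on a hypothetical deterministic chain 0 → 1 → 2 → terminal. A terminating step uses the target $\delta_{ik}$ only, since terminal states carry no occupancy in $X$; the learned matrix should approach $[I - \gamma Q]^{-1}$.

```python
import numpy as np

def sr_td_update(X, i, alpha, gamma, j=None):
    """One TD step on row i of the SR matrix X after leaving state i.

    Target is onehot(i) + gamma * X[j] for a non-terminal successor j, or just
    onehot(i) if the step terminated (terminal states are not tracked in X).
    """
    target = np.eye(X.shape[0])[i]
    if j is not None:
        target = target + gamma * X[j]
    X = X.copy()
    X[i] += alpha * (target - X[i])
    return X

gamma, alpha = 0.9, 0.5
X = np.zeros((3, 3))

# Hypothetical deterministic chain 0 -> 1 -> 2 -> terminal
for _ in range(200):
    X = sr_td_update(X, 0, alpha, gamma, j=1)
    X = sr_td_update(X, 1, alpha, gamma, j=2)
    X = sr_td_update(X, 2, alpha, gamma)      # 2 -> terminal

Q = np.array([[0.0, 1.0, 0.0],
              [0.0, 0.0, 1.0],
              [0.0, 0.0, 0.0]])
X_true = np.linalg.inv(np.eye(3) - gamma * Q)
print(np.allclose(X, X_true))  # True
```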

## Joint TD learning

We can learn $X$ and $\boldsymbol{w}$ jointly, applying both of the updates above on every observed transition
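Joint learning can be sketched as one loop that applies the SR update and then the weight update on every step. This assumes a hypothetical deterministic chain 0 → 1 → 2 → 3 → terminal with reward $z = 2$ on termination. Because $X$ is learned from zero, $\boldsymbol{w}$ only starts moving once the features become non-zero; at convergence $\boldsymbol{w} \approx \boldsymbol{h}$ and $X\boldsymbol{w} \approx \bar{\boldsymbol{r}}$.

```python
import numpy as np

gamma, alpha = 0.9, 0.1
n = 4       # hypothetical deterministic chain 0 -> 1 -> 2 -> 3 -> terminal
z = 2.0     # hypothetical reward for terminating from state 3

X = np.zeros((n, n))   # learned SR, row i is x_i
w = np.zeros(n)        # learned linear readout weights
eye = np.eye(n)

for _ in range(2000):  # episodes
    for i in range(n):
        terminal = i == n - 1
        j = None if terminal else i + 1
        # 1) SR update: move X[i] towards onehot(i) + gamma * X[j]
        sr_target = eye[i] if terminal else eye[i] + gamma * X[j]
        X[i] += alpha * (sr_target - X[i])
        # 2) Value update with the current features (semi-gradient TD)
        v_target = z if terminal else gamma * (X[j] @ w)
        w -= alpha * (X[i] @ w - v_target) * X[i]

r_true = z * gamma ** np.arange(n - 1, -1, -1)   # exact returns (z g^3, z g^2, z g, z)
print(X @ w, r_true)
```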

## State values to action values

We now consider the setting where in each state $i$ we have access to a fixed, discrete set of actions $a\in\mathcal{A}$

$\pi(a|i)$ = the probability of choosing action $a$ in state $i$

$\hat r_\pi(j; \boldsymbol{w})$ = linear state value estimator assuming we follow policy $\pi$

$h_i(a)$ = expected immediate termination reward from state $i$, conditional on taking action $a$, i.e. $h_i(a) = \sum_k s_{ik}(a)\bar{z}_k$

$Q_{ij}(a)$, $s_{ik}(a)$ = transition probabilities as before, conditional on taking action $a$

We now want to learn *action values*, a.k.a. *Q-values*: $\hat{q}_\pi(i, a;\boldsymbol{w}) \approx h_i(a) + \gamma \sum_j Q_{ij}(a) \hat r_\pi(j; \boldsymbol{w})$

These ideally satisfy the analogous consistency: $\hat{q}_\pi(i, a;\boldsymbol{w}) \approx \sum_{k} s_{ik}(a) \bar{z}_k + \gamma \sum_{j} Q_{ij}(a) \sum_{a'} \pi(a' \mid j)\, \hat{q}_\pi(j, a';\boldsymbol{w})$, using $\hat r_\pi(j; \boldsymbol{w}) = \sum_{a'} \pi(a' \mid j)\, \hat{q}_\pi(j, a';\boldsymbol{w})$

The equivalent successor representation is now the tensor: $X^\pi_{iaj} = \sum_{\tau=0}^\infty \gamma^\tau p(s_{t + \tau} = j \mid s_t = i, a_t = a, a_{>t}\sim\pi)$

Its definition is now the "*discounted expected occupancy of state $j$ after starting at state $i$ and taking action $a$, then following policy $\pi$ thereafter*"
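For a small MDP, $X^\pi$ can be computed exactly: the first step uses $Q(a)$, after which the agent follows $\pi$, so $X^\pi_{ia\cdot} = \boldsymbol{e}_i + \gamma\, Q_{i\cdot}(a)\, M^\pi$, where $M^\pi = [I - \gamma Q^\pi]^{-1}$ is the state SR under $\pi$ and $Q^\pi_{ij} = \sum_a \pi(a \mid i)\, Q_{ij}(a)$. The transition tensor and policy below are hypothetical (randomly generated or made up for illustration); averaging $X^\pi$ over $\pi(a \mid i)$ should recover $M^\pi$.

```python
import numpy as np

gamma = 0.9
n_states, n_actions = 3, 2
rng = np.random.default_rng(0)

# Hypothetical non-terminal transitions Q[a, i, j]: random, rescaled so that each
# row sums to 0.8 (the remaining 0.2 is the probability of terminating)
Q = rng.random((n_actions, n_states, n_states))
Q = 0.8 * Q / Q.sum(axis=2, keepdims=True)

# Hypothetical stochastic policy pi[i, a] (rows sum to 1)
pi = np.array([[0.5, 0.5],
               [0.9, 0.1],
               [0.2, 0.8]])

# State-to-state transitions under pi, and the corresponding state SR
Q_pi = np.einsum('ia,aij->ij', pi, Q)
M_pi = np.linalg.inv(np.eye(n_states) - gamma * Q_pi)

# Action-conditional SR: X_pi[i, a, j] = delta_ij + gamma * sum_k Q[a, i, k] * M_pi[k, j]
X_pi = np.eye(n_states)[:, None, :] + gamma * np.einsum('aik,kj->iaj', Q, M_pi)

print(X_pi.shape)  # (3, 2, 3)
```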

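TD learning of both the SR and the return approximation looks the same, except that the update rules also index over the selected action, and the bootstrap at the next state $j$ uses a next action $a'$ (e.g. sampled from $\pi$, or chosen greedily for control)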

## Key benefit of learning SR

Given a well-learned successor representation, we effectively have a model of the world's (discounted, multi-step) transition structure that is independent of the reward structure. This is only approximately true: the reward structure shapes the learned policy, which in turn shapes the action-conditional SR.

So, if the reward structure changes, the agent can keep its state representations (or bootstrap learning from the pretrained SR) and only (or mainly) needs to relearn the linear weights $\boldsymbol{w}$ of the Q-value approximation
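To illustrate reward revaluation in the idealised tabular case (exact SR, so $\boldsymbol{w}^* = \boldsymbol{h}$), the sketch below keeps one SR `X` for a hypothetical 4-state random walk and recomputes values for two different terminal reward layouts without touching `X`. The helper `values` is hypothetical; in the learned setting, the analogue is re-running TD on $\boldsymbol{w}$ only.

```python
import numpy as np

gamma = 0.9
n = 4

# Hypothetical symmetric random walk on states 0..3; stepping off either end terminates
Q = 0.5 * (np.eye(n, k=1) + np.eye(n, k=-1))
s = np.zeros((n, 2))
s[0, 0] = 0.5        # exit to the left terminal from state 0
s[n - 1, 1] = 0.5    # exit to the right terminal from state n - 1

# SR is computed (or learned) once; it does not depend on the rewards
X = np.linalg.inv(np.eye(n) - gamma * Q)

def values(z_bar: np.ndarray) -> np.ndarray:
    """Values for terminal rewards z_bar, reusing the fixed SR X with w* = h."""
    h = s @ z_bar
    return X @ h

v_right = values(np.array([0.0, 1.0]))  # reward at the right exit
v_left = values(np.array([1.0, 0.0]))   # reward moved to the left exit; X unchanged
print(v_right, v_left)
```

By symmetry of the walk, moving the reward from the right exit to the left simply mirrors the value profile.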

## Demonstration

```python
from typing import List, Tuple

import numpy as np


class GridWorld:
    """
    Stochastic grid world. Entering a reward cell yields `term_reward`, ends the
    episode, and restarts the agent in a random corner.
    """

    def __init__(
        self,
        height: int,
        width: int,
        reward_cells: List[Tuple[int, int]],
        prob_correct: float = 0.5,
        term_reward: float = 1.0,
    ):
        self.height = height
        self.width = width

        self.prob_correct = prob_correct
        self.term_reward = term_reward

        self.current_state = None

        self.reward_cells = set(reward_cells)
        for i, j in reward_cells:
            assert 0 <= i < self.height
            assert 0 <= j < self.width

    def init(self) -> None:
        # Start each episode in a random corner
        # (alternative: a uniformly random non-reward cell)
        self.current_state = (
            int(np.random.choice([0, self.height - 1])),
            int(np.random.choice([0, self.width - 1])),
        )

    def det_step_inner(self, action: int) -> Tuple[int, int]:
        """Deterministic move; moves that would leave the grid keep the agent in place."""
        new_state = [self.current_state[0], self.current_state[1]]
        if action == 0:  # N
            new_state[0] = min(new_state[0] + 1, self.height - 1)
        elif action == 1:  # E
            new_state[1] = min(new_state[1] + 1, self.width - 1)
        elif action == 2:  # W
            new_state[1] = max(new_state[1] - 1, 0)
        elif action == 3:  # S
            new_state[0] = max(new_state[0] - 1, 0)
        return tuple(new_state)

    def step(self, action: int) -> Tuple[Tuple[int, int], Tuple[int, int], float, bool]:
        """
        Take the intended action with probability self.prob_correct, otherwise a
        uniformly random action (which may coincide with the intended one).
        Actions 0, 1, 2, 3 = N, E, W, S.
        Returns (old_state, new_state, reward, terminal_flag).
        """
        u = np.random.random()
        if u <= self.prob_correct:
            new_state = self.det_step_inner(action)
        else:
            new_state = self.det_step_inner(np.random.randint(4))
        old_state = tuple(self.current_state)
        if new_state in self.reward_cells:
            reward = self.term_reward
            terminal_flag = True
            self.init()
        else:
            reward = 0.0
            terminal_flag = False
            self.current_state = new_state
        return old_state, new_state, reward, terminal_flag
```



```python
from abc import ABC, abstractmethod
from typing import List, Tuple

import numpy as np
from matplotlib.axes import Axes


class GridAgent(ABC):
    """Abstract base class for tabular agents on a height x width grid with 4 actions (N, E, W, S)."""

    def __init__(
        self,
        height: int,
        width: int,
        discount_factor: float,
        learning_rate: float
    ):
        self.width = width
        self.height = height

        assert 0.0 <= discount_factor <= 1.0
        self.discount_factor = discount_factor

        assert learning_rate > 0
        self.learning_rate = learning_rate

    @abstractmethod
    def take_action(self, state: Tuple[int, int]) -> int:
        raise NotImplementedError

    @abstractmethod
    def observe_transition(
        self,
        old_state: Tuple[int, int],
        new_state: Tuple[int, int],
        action: int,
        transition_reward: float,
        terminal: bool
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def visualise(self, *axes: Axes):
        raise NotImplementedError


class EpsilonGreedyQAgent(GridAgent):
    """Tabular Q-learning with an epsilon-greedy behaviour policy."""

    def __init__(self, epsilon: float, height: int, width: int, discount_factor: float, learning_rate: float):
        super().__init__(height, width, discount_factor, learning_rate)
        self.epsilon = epsilon
        self.q_values = np.zeros([height, width, 4])

    def get_q_values(self, state: Tuple[int, int]) -> np.ndarray:
        i, j = state
        return self.q_values[i, j]    # [4]

    def change_q_value(self, state: Tuple[int, int], action: int, change: float) -> None:
        i, j = state
        self.q_values[i, j, action] += self.learning_rate * change

    def take_action(self, state: Tuple[int, int]) -> int:
        # Random action with probability epsilon, otherwise greedy
        u = np.random.random()
        if u <= self.epsilon:
            return np.random.randint(4)
        else:
            return int(np.argmax(self.get_q_values(state)))

    def observe_transition(self, old_state: Tuple[int, int], new_state: Tuple[int, int], action: int, transition_reward: float, terminal: bool) -> None:
        # Q-learning target: r if terminal, else r + gamma * max_a' Q(s', a')
        bootstrapped_new_value = (
            transition_reward if terminal else
            transition_reward + self.discount_factor * self.get_q_values(new_state).max()
        )
        td_error = bootstrapped_new_value - self.get_q_values(old_state)[action]
        self.change_q_value(old_state, action, td_error)

    def visualise(self, axes: List[Axes]):
        assert len(axes) == 4
        for i, ax in enumerate(axes):
            ax.clear()    # redraw instead of stacking a new image on every call
            ax.imshow(self.q_values[:, :, i], origin = 'lower')
            ax.set_title('NEWS'[i])


class SuccessorRepresentationEpsilonGreedyQAgent(GridAgent):
    """
    Q(s, a) = sum_{s'} M[s, a, s'] * w[s'], with both M and w learned by TD.
    M counts discounted occupancy from the *next* state onwards, so w[s'] acts as
    the expected reward for entering s'.
    """

    def __init__(self, epsilon: float, height: int, width: int, discount_factor: float, learning_rate: float, learning_rate_v: float):
        super().__init__(height, width, discount_factor, learning_rate)
        self.epsilon = epsilon
        self.sf_M = np.zeros([height, width, 4, height, width])    # (s, a) -> s'
        self.w = np.random.randn(height, width)
        self.learning_rate_v = learning_rate_v

    def take_action(self, state: Tuple[int, int]) -> int:
        u = np.random.random()
        if u <= self.epsilon:
            return np.random.randint(4)
        else:
            return int(np.argmax(self.get_q_values(state)))

    def get_q_values(self, state: Tuple[int, int]) -> np.ndarray:
        i, j = state
        from_state_sf_M = self.sf_M[i, j]    # [4, h, w]
        return (from_state_sf_M * self.w[None]).sum(-1).sum(-1)   # [4]

    def change_q_value(self, state: Tuple[int, int], action: int, change: float) -> None:
        # Semi-gradient step on w, since dQ(s, a)/dw = M[s, a]
        i, j = state
        self.w += self.learning_rate_v * change * self.sf_M[i, j, action]

    def observe_transition(self, old_state: Tuple[int, int], new_state: Tuple[int, int], action: int, transition_reward: float, terminal: bool) -> None:

        # Update M[s, a, :] towards onehot(s') + gamma * M[s', a', :]
        bootstrapped_new_sr = np.zeros([self.height, self.width])
        bootstrapped_new_sr[new_state[0], new_state[1]] = 1.0
        if not terminal:
            # a' is the greedy action at s', matching the max in the value target below
            next_action = int(np.argmax(self.get_q_values(new_state)))
            bootstrapped_new_sr = bootstrapped_new_sr + self.discount_factor * self.sf_M[new_state[0], new_state[1], next_action]
        td_error = bootstrapped_new_sr - self.sf_M[old_state[0], old_state[1], action]
        self.sf_M[old_state[0], old_state[1], action] += self.learning_rate * td_error

        # Update Q[s, a] for s, a, by changing w
        bootstrapped_new_value = (
            transition_reward if terminal else
            transition_reward + self.discount_factor * self.get_q_values(new_state).max()
        )
        td_error = bootstrapped_new_value - self.get_q_values(old_state)[action]
        self.change_q_value(old_state, action, td_error)

    def visualise(self, axes: List[Axes], source_state: Tuple[int, int]):
        assert len(axes) >= 5
        for i, ax in enumerate(axes[:4]):
            ax.clear()
            ax.imshow(self.sf_M[source_state[0], source_state[1], i], origin = 'lower')
            ax.set_title('NEWS'[i] + f' from {source_state}')
        axes[-1].clear()
        axes[-1].set_title('w')
        axes[-1].imshow(self.w, origin = 'lower')
```

```python
from tqdm import tqdm

from matplotlib import pyplot as plt

height = 5
width = 5

world = GridWorld(
    height = height,
    width = width,
    reward_cells = [(2, 2)],
    prob_correct = 0.5,
    term_reward = 10.0
)

# epsilon = 1.0 makes the behaviour policy uniformly random; Q-learning is
# off-policy, so it still learns greedy Q-values
agent = EpsilonGreedyQAgent(
    epsilon = 1.0,
    height = height,
    width = width,
    discount_factor = 0.99,
    learning_rate = 0.01,
)

world.init()

fig, axes = plt.subplots(2, 4, figsize = (12, 8))

num_steps = 5_000

# Phase 1: reward at (2, 2); snapshots go in the top row of axes
for ns in tqdm(range(num_steps)):
    action = agent.take_action(world.current_state)
    old_state, new_state, reward, terminal_flag = world.step(action)
    agent.observe_transition(old_state, new_state, action, reward, terminal_flag)

    if terminal_flag:
        plt.close('all')
        agent.visualise(axes[0])
        fig.suptitle(f"{ns}\n{world.reward_cells}")
        fig.savefig('q_values')

# Phase 2: move the reward to (2, 0) and continue from the current Q-values;
# snapshots go in the bottom row of axes
world.reward_cells = {(2, 0)}

for ns in tqdm(range(num_steps)):
    action = agent.take_action(world.current_state)
    old_state, new_state, reward, terminal_flag = world.step(action)
    agent.observe_transition(old_state, new_state, action, reward, terminal_flag)

    if terminal_flag:
        plt.close('all')
        agent.visualise(axes[1])
        fig.suptitle(f"{ns}\n{world.reward_cells}")
        fig.savefig('q_values')
```


```python
height = 5
width = 5

world = GridWorld(
    height = height,
    width = width,
    reward_cells = [(2, 2)],
    prob_correct = 0.7,
    term_reward = 10.0
)

agent = SuccessorRepresentationEpsilonGreedyQAgent(
    epsilon = 1.0,
    height = height,
    width = width,
    discount_factor = 0.99,
    learning_rate = 0.01,
    learning_rate_v = 1.0,
)

world.init()

num_steps = 5_000

fig, axes = plt.subplots(2, 5, figsize = (15, 8))

# Phase 1: reward at (2, 2); plot the SR from source state (2, 1) and w in the top row
for ns in tqdm(range(num_steps)):
    action = agent.take_action(world.current_state)
    old_state, new_state, reward, terminal_flag = world.step(action)
    agent.observe_transition(old_state, new_state, action, reward, terminal_flag)

    if terminal_flag:
        plt.close('all')
        agent.visualise(axes[0], (2, 1))
        fig.suptitle(f"{ns}\n{world.reward_cells}")
        fig.savefig('sr_q_values')

# Phase 2: move the reward to (2, 0), keeping the learned SR and w; bottom row
world.reward_cells = {(2, 0)}

for ns in tqdm(range(num_steps)):
    action = agent.take_action(world.current_state)
    old_state, new_state, reward, terminal_flag = world.step(action)
    agent.observe_transition(old_state, new_state, action, reward, terminal_flag)

    if terminal_flag:
        plt.close('all')
        agent.visualise(axes[1], (2, 1))
        fig.suptitle(f"{ns}\n{world.reward_cells}")
        fig.savefig('sr_q_values')
```

