# XAI In Action - pgeon

This notebook shows the current functionalities of the **pgeon** library.

## Preparation

Load an environment, an agent and a discretizer: the three elements needed to generate a Policy Graph.

In [31]:
import random

import gymnasium as gym

from example.cartpole.discretizer import CartpoleDiscretizer

In [32]:
import torch

In [33]:
environment = gym.make('CartPole-v1')
discretizer = CartpoleDiscretizer()
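
To make the role of the discretizer concrete, the sketch below maps a raw CartPole observation (cart position, cart velocity, pole angle, pole angular velocity) to a tuple of symbolic predicates. The function name `toy_discretize`, the thresholds and the two-valued angle predicate are assumptions made for illustration; they are not taken from `CartpoleDiscretizer`, which uses a richer set of angle predicates.

```python
from typing import Sequence, Tuple


def toy_discretize(obs: Sequence[float]) -> Tuple[str, str, str]:
    """Map a CartPole observation to symbolic predicates (illustrative thresholds)."""
    position, velocity, angle, _angular_velocity = obs

    # Hypothetical cut-offs: +-1.0 splits the track into three zones.
    if position < -1.0:
        pos = "Position(LEFT)"
    elif position > 1.0:
        pos = "Position(RIGHT)"
    else:
        pos = "Position(MIDDLE)"

    # The sign alone decides the velocity and angle predicates in this toy version.
    vel = "Velocity(LEFT)" if velocity < 0 else "Velocity(RIGHT)"
    ang = "Angle(LEFT)" if angle < 0 else "Angle(RIGHT)"
    return pos, vel, ang


example_state = toy_discretize([0.02, -0.3, 0.01, 0.4])
print(example_state)
```

Many different observations collapse into the same tuple, which is what keeps the number of graph nodes small.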

In [34]:
from pgeon import Agent
from ray.rllib.algorithms.algorithm import Algorithm

class CartpoleAgent(Agent):
    def __init__(self, path):
        self.agent = Algorithm.from_checkpoint(path)

    def act(self, state):
        return self.agent.compute_single_action(state)

In [35]:
agent = CartpoleAgent('checkpoints/PPO_CartPole-v1_1acbb_00000_0_2023-12-05_19-28-36/checkpoint_000000')

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))


## Policy Graph generation

In [36]:
from pgeon import PolicyGraph

Policy Graphs are instantiated with an environment and a discretizer.

In [37]:
pg = PolicyGraph(environment, discretizer)

We generate a Policy Graph with the `fit()` function, in this case from 200 episode trajectories of our agent. If the PG has already been fit, passing `update=True` updates it with the new trajectories instead of regenerating it.

In [38]:
pg = pg.fit(agent, num_episodes=200, update=False)

Fitting PG...: 100%|██████████| 200/200 [00:15<00:00, 12.84it/s]


In [39]:
print(f'Number of nodes: {len(pg.nodes)}')
print(f'Number of edges: {len(pg.edges)}')

Number of nodes: 14
Number of edges: 132


Each node has information about a discretized state:

In [40]:
arbitrary_state = list(pg.nodes)[0]

print(arbitrary_state)
print(f'  Times visited: {pg.nodes[arbitrary_state]["frequency"]}')
print(f'  p(s):          {pg.nodes[arbitrary_state]["probability"]:.3f}')

(Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT))
  Times visited: 2557
  p(s):          0.071


Each edge has information about a transition between states:

In [41]:
arbitrary_edge = list(pg.edges)[0]

print(f'From:    {arbitrary_edge[0]}')
print(f'Action:  {arbitrary_edge[2]}')
print(f'To:      {arbitrary_edge[1]}')
print(f'  Times visited:      {pg[arbitrary_edge[0]][arbitrary_edge[1]][arbitrary_edge[2]]["frequency"]}')
print(f'  p(s_to,a | s_from): {pg[arbitrary_edge[0]][arbitrary_edge[1]][arbitrary_edge[2]]["probability"]:.3f}')

From:    (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT))
Action:  1
To:      (Position(MIDDLE), Velocity(LEFT), Angle(STUCK_LEFT))
  Times visited:      379
  p(s_to,a | s_from): 0.148
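
In the outputs above, the edge probability matches the edge frequency divided by the frequency of its source node (379 / 2557 ≈ 0.148). The sketch below reproduces that kind of frequency normalisation on toy trajectories laid out as flat `[state0, action0, state1, ..., stateN]` lists. It is an assumption inferred from the numbers in this notebook, not pgeon's actual implementation; `count_graph`, `normalise` and the toy states `"A"` and `"B"` are hypothetical.

```python
from collections import Counter


def count_graph(trajectories):
    """Count state visits and (state, action, next_state) transitions."""
    state_freq = Counter()
    edge_freq = Counter()
    for traj in trajectories:
        states, actions = traj[0::2], traj[1::2]
        state_freq.update(states)
        # zip stops at the shortest input, so the final state gets no outgoing edge.
        edge_freq.update(zip(states, actions, states[1:]))
    return state_freq, edge_freq


def normalise(state_freq, edge_freq):
    """Turn raw counts into p(s) and p(s_to, a | s_from)."""
    total_visits = sum(state_freq.values())
    p_state = {s: n / total_visits for s, n in state_freq.items()}

    # Assumption: edges are normalised by the outgoing transitions of their source.
    outgoing = Counter()
    for (s_from, _action, _s_to), n in edge_freq.items():
        outgoing[s_from] += n
    p_edge = {edge: n / outgoing[edge[0]] for edge, n in edge_freq.items()}
    return p_state, p_edge


toy_trajectories = [
    ["A", 0, "B", 1, "A", 0, "B"],
    ["A", 1, "A", 0, "B"],
]
state_freq, edge_freq = count_graph(toy_trajectories)
p_state, p_edge = normalise(state_freq, edge_freq)
print(p_edge[("A", 0, "B")])
```

With this normalisation, the probabilities of all edges leaving a state sum to 1.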


The `PolicyGraph` object also stores the full discretized episode trajectories of the last fit.

In [42]:
len(pg._trajectories_of_last_fit)

200

Each trajectory is stored as a (state0, action0, state1, ..., stateN) list.

In [43]:
pg._trajectories_of_last_fit[0]

[(Position(MIDDLE), Velocity(LEFT), Angle(STUCK_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(STUCK_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(FALLING_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(FALLING_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(FALLING_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(STUCK_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(STUCK_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), An

## Loading and saving Policy Graphs

### Pickle

Saving as pickle lets you restore the full state of the object.

In [44]:
pg.save('pickle', './ppo-cartpole.pickle')

In [45]:
pg_pickle = PolicyGraph.from_pickle('./ppo-cartpole.pickle')

print(f'Number of nodes:             {len(pg_pickle.nodes)}')
print(f'Number of edges:             {len(pg_pickle.edges)}')
print(f'Num. of stored trajectories: {len(pg_pickle._trajectories_of_last_fit)}')

Number of nodes:             14
Number of edges:             132
Num. of stored trajectories: 200


### CSV

Saving as CSV creates three separate CSV files for node, edge and trajectory information.

In [46]:
import csv

In [47]:
pg.save('csv', ['./ppo-cartpole_nodes.csv', './ppo-cartpole_edges.csv', './ppo-cartpole_trajectories.csv'])

In [48]:
with open('ppo-cartpole_nodes.csv', 'r', newline='') as f:
    csv_r = csv.reader(f)
    for i in range(10):
        print(next(csv_r))

['id', 'value', 'p(s)', 'frequency']
['0', 'Position(MIDDLE)&Velocity(LEFT)&Angle(STABILIZING_RIGHT)', '0.07128519654307221', '2557']
['1', 'Position(MIDDLE)&Velocity(LEFT)&Angle(STABILIZING_LEFT)', '0.0051017563423473656', '183']
['2', 'Position(MIDDLE)&Velocity(RIGHT)&Angle(STABILIZING_LEFT)', '0.22272093671591858', '7989']
['3', 'Position(MIDDLE)&Velocity(LEFT)&Angle(FALLING_LEFT)', '0.0046835795929746305', '168']
['4', 'Position(MIDDLE)&Velocity(RIGHT)&Angle(FALLING_RIGHT)', '0.20448843044326737', '7335']
['5', 'Position(MIDDLE)&Velocity(LEFT)&Angle(STUCK_LEFT)', '0.020797323668804015', '746']
['6', 'Position(MIDDLE)&Velocity(RIGHT)&Angle(STUCK_RIGHT)', '0.17056035684415946', '6118']
['7', 'Position(MIDDLE)&Velocity(LEFT)&Angle(FALLING_RIGHT)', '0.12863116810705325', '4614']
['8', 'Position(MIDDLE)&Velocity(LEFT)&Angle(STUCK_RIGHT)', '0.036297741845553386', '1302']


Edges and trajectories use the IDs of the nodes, from the corresponding node CSV file.

In [49]:
with open('ppo-cartpole_edges.csv', 'r', newline='') as f:
    csv_r = csv.reader(f)
    for i in range(10):
        print(next(csv_r))

['from', 'to', 'action', 'p(s)', 'frequency']
['0', '5', '1', '0.1482205709816191', '379']
['0', '3', '1', '0.048494329292139225', '124']
['0', '0', '0', '0.1603441533046539', '410']
['0', '0', '1', '0.09777082518576456', '250']
['0', '8', '1', '0.03363316386390301', '86']
['0', '11', '1', '0.20375439968713335', '521']
['0', '9', '1', '0.12084473992960501', '309']
['0', '10', '1', '0.028549080954243255', '73']
['0', '7', '1', '0.03989049667579194', '102']


Each trajectory is stored as one row: a flat (state0, action0, state1, ..., stateN) list in which states are given by their node IDs.

In [50]:
with open('ppo-cartpole_trajectories.csv', 'r', newline='') as f:
    csv_r = csv.reader(f)
    for i in range(1):
        print(next(csv_r))

['5', '0', '0', '1', '5', '0', '0', '1', '3', '0', '0', '1', '3', '0', '0', '1', '3', '0', '0', '0', '0', '1', '5', '0', '0', '1', '5', '0', '0', '1', '5', '1', '3', '0', '5', '0', '0', '1', '5', '0', '0', '1', '8', '0', '7', '1', '8', '0', '7', '1', '8', '0', '7', '1', '8', '1', '1', '0', '8', '1', '1', '0', '8', '1', '1', '0', '7', '0', '7', '1', '7', '1', '1', '0', '7', '0', '7', '1', '7', '1', '1', '0', '7', '1', '8', '0', '7', '1', '8', '1', '2', '0', '8', '0', '7', '1', '8', '0', '7', '1', '8', '1', '2', '0', '8', '0', '7', '1', '8', '0', '7', '1', '8', '1', '2', '0', '7', '1', '2', '0', '7', '1', '2', '0', '7', '1', '6', '1', '2', '0', '6', '1', '2', '0', '6', '0', '7', '1', '6', '1', '2', '0', '6', '1', '2', '0', '6', '1', '2', '0', '4', '0', '7', '1', '4', '1', '2', '1', '2', '0', '2', '1', '2', '0', '6', '0', '4', '1', '6', '0', '4', '1', '6', '0', '4', '1', '6', '1', '2', '1', '2', '0', '2', '0', '6', '0', '4', '1', '6', '1', '2', '0', '6', '1', '2', '0', '6', '0', '4', '0',
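
Because each trajectory row alternates node IDs and actions, it can be decoded with the `id` and `value` columns of the node file. The sketch below does this on in-memory CSV text that mimics the layout printed above, with probabilities rounded for brevity. The helper `decode_trajectory` is hypothetical and not part of pgeon.

```python
import csv
import io

# In-memory stand-ins for the node and trajectory files (same column layout).
nodes_csv = io.StringIO(
    "id,value,p(s),frequency\n"
    "0,Position(MIDDLE)&Velocity(LEFT)&Angle(STABILIZING_RIGHT),0.0713,2557\n"
    "5,Position(MIDDLE)&Velocity(LEFT)&Angle(STUCK_LEFT),0.0208,746\n"
)
trajectories_csv = io.StringIO("5,0,0,1,5\n")

# Map each node ID (kept as a string, as read from the CSV) to its state description.
id_to_state = {row["id"]: row["value"] for row in csv.DictReader(nodes_csv)}


def decode_trajectory(row):
    """Split a flat row into (state, action, next_state) triples."""
    states = [id_to_state[node_id] for node_id in row[0::2]]
    actions = [int(a) for a in row[1::2]]
    return list(zip(states, actions, states[1:]))


first_row = next(csv.reader(trajectories_csv))
transitions = decode_trajectory(first_row)
for s_from, action, s_to in transitions:
    print(f"{s_from} --{action}--> {s_to}")
```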

There are two ways of loading Policy Graphs from CSV files: from nodes and trajectories, or from nodes and edges. When loading from nodes and edges, though, episode trajectories cannot be restored.

In [51]:
pg_csv = PolicyGraph.from_nodes_and_trajectories('./ppo-cartpole_nodes.csv', './ppo-cartpole_trajectories.csv',
                                                 environment, discretizer)
print(f'Number of nodes:             {len(pg_csv.nodes)}')
print(f'Number of edges:             {len(pg_csv.edges)}')
print(f'Num. of stored trajectories: {len(pg_csv._trajectories_of_last_fit)}')

Number of nodes:             14
Number of edges:             132
Num. of stored trajectories: 200


In [52]:
pg_csv = PolicyGraph.from_nodes_and_edges('./ppo-cartpole_nodes.csv', './ppo-cartpole_edges.csv',
                                          environment, discretizer)
print(f'Number of nodes:             {len(pg_csv.nodes)}')
print(f'Number of edges:             {len(pg_csv.edges)}')
print(f'Num. of stored trajectories: {len(pg_csv._trajectories_of_last_fit)}')

Number of nodes:             14
Number of edges:             132
Num. of stored trajectories: 0


### Gram

PGs can also be exported to the [gram](https://neo4j.com/developer-blog/gram-a-data-graph-format/) format, allowing visualization using Neo4j. Episode trajectories cannot be stored in this format, though.

PGs currently cannot be loaded from a Gram file.

In [53]:
pg.save('gram', './ppo-cartpole.gram')

In [54]:
!head ./ppo-cartpole.gram


CREATE (s0:State {
  uid: "s0",
  value: "Position(MIDDLE)&Velocity(LEFT)&Angle(STABILIZING_RIGHT)",
  probability: 0.07128519654307221, 
  frequency:2557
});
CREATE (s1:State {
  uid: "s1",
  value: "Position(MIDDLE)&Velocity(LEFT)&Angle(STABILIZING_LEFT)",


In [55]:
!tail ./ppo-cartpole.gram

MATCH (s13:State) WHERE s13.uid = "s13" MATCH (s0:State) WHERE s0.uid = "s0" CREATE (s13)-[:a0 {probability:0.10526315789473684, frequency:22}]->(s0);
MATCH (s13:State) WHERE s13.uid = "s13" MATCH (s4:State) WHERE s4.uid = "s4" CREATE (s13)-[:a0 {probability:0.028708133971291867, frequency:6}]->(s4);
MATCH (s13:State) WHERE s13.uid = "s13" MATCH (s4:State) WHERE s4.uid = "s4" CREATE (s13)-[:a1 {probability:0.019138755980861243, frequency:4}]->(s4);
MATCH (s13:State) WHERE s13.uid = "s13" MATCH (s13:State) WHERE s13.uid = "s13" CREATE (s13)-[:a0 {probability:0.023923444976076555, frequency:5}]->(s13);
MATCH (s13:State) WHERE s13.uid = "s13" MATCH (s13:State) WHERE s13.uid = "s13" CREATE (s13)-[:a1 {probability:0.03827751196172249, frequency:8}]->(s13);
MATCH (s13:State) WHERE s13.uid = "s13" MATCH (s6:State) WHERE s6.uid = "s6" CREATE (s13)-[:a1 {probability:0.08133971291866028, frequency:17}]->(s6);
MATCH (s13:State) WHERE s13.uid = "s13" MATCH (s5:State) WHERE s5.uid = "s5" CREA

## Using PG-based policies

Using the `PGBasedPolicy`, we can create policies that replicate an agent's behavior, based on their generated Policy Graph. These policies are subclasses of the `pgeon.Agent` class.

The policy mode (greedy/stochastic) can be specified via the `PGBasedPolicyMode` enum. The behavior when encountering an unknown node (select random action/search nearest node in PG) can be specified via the `PGBasedPolicyNodeNotFoundMode` enum.

In [56]:
from pgeon import PGBasedPolicy, PGBasedPolicyMode, PGBasedPolicyNodeNotFoundMode

In [57]:
policy = PGBasedPolicy(pg, mode=PGBasedPolicyMode.GREEDY,
                       node_not_found_mode=PGBasedPolicyNodeNotFoundMode.RANDOM_UNIFORM)
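
The difference between the two policy modes can be sketched without pgeon. Assuming the outgoing edges of a node have been aggregated into per-action probabilities, a greedy policy picks the most probable action, whereas a stochastic one samples actions in proportion to those probabilities. The function names and the toy distribution below are hypothetical; how `PGBasedPolicy` aggregates edges internally is not shown in this notebook.

```python
import random


def greedy_action(action_probs):
    """Return the action with the highest probability."""
    return max(action_probs, key=action_probs.get)


def stochastic_action(action_probs, rng=random):
    """Sample an action in proportion to its probability."""
    actions = list(action_probs)
    weights = [action_probs[a] for a in actions]
    return rng.choices(actions, weights=weights, k=1)[0]


def random_uniform_action(n_actions, rng=random):
    """Fallback for a state that is absent from the graph."""
    return rng.randrange(n_actions)


toy_action_probs = {0: 0.16, 1: 0.84}  # made-up values for one node
rng = random.Random(0)  # seeded so the sketch is reproducible

greedy = greedy_action(toy_action_probs)
samples = [stochastic_action(toy_action_probs, rng) for _ in range(1000)]
fallback = random_uniform_action(2, rng)
print(greedy, samples.count(1) / len(samples), fallback)
```

A greedy policy is deterministic for known states, while the stochastic variant preserves the action variability observed in the original agent.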

In [75]:
# Render the PG-based policy and log the discretized states it visits.
# The outer loop runs until it is interrupted manually.
environment = gym.make('CartPole-v1', render_mode='human')
list_states = []  # distinct discretized states, in order of first appearance
while True:
    obs, _ = environment.reset()
    done = False
    steps = 0
    all_states = []
    while not done and steps < 200:
        action = policy.act(obs)
        # Gymnasium's step() returns (obs, reward, terminated, truncated, info)
        obs, _, terminated, truncated, _ = environment.step(action)
        done = terminated or truncated
        state = policy.pg.discretizer.discretize(obs)
        if state in list_states:
            all_states.append(list_states.index(state))
        else:
            # The first occurrence of a state is registered but not logged
            list_states.append(state)
        steps += 1
    # Append this episode's state indices to the log file
    with open('/tmp/all-states.txt', 'a') as f:
        for state in all_states:
            f.write(f's{state}\n')
    # print(f'Observed state:  {obs}')
    # print(f'Discretization:  {policy.pg.discretizer.discretize(obs)}')
    # print(f'Selected action: {action}')

KeyboardInterrupt: 