# XAI In Action - pgeon

This notebook shows the current functionality of the **pgeon** library.

## Preparation

We load an environment, an agent and a discretizer: the elements needed to generate a Policy Graph.

In [73]:
import gymnasium as gym

from example.cartpole.discretizer import CartpoleDiscretizer

In [74]:
import torch

In [75]:
environment = gym.make('CartPole-v1')
discretizer = CartpoleDiscretizer()
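`CartpoleDiscretizer` comes from the examples in the pgeon repository and is not reproduced in this notebook. As a rough, hypothetical illustration of what a discretizer does, the sketch below maps a raw CartPole-v1 observation (cart position, cart velocity, pole angle, pole angular velocity) to a tuple of predicates shaped like the ones printed later on. The thresholds and our reading of the STUCK/FALLING/STABILIZING labels are assumptions, not the example's actual rules.

```python
from typing import Sequence, Tuple

def discretize(obs: Sequence[float]) -> Tuple[str, str, str]:
    """Hypothetical CartPole discretizer; thresholds are illustrative assumptions."""
    x, x_dot, theta, theta_dot = obs
    # Cart position: only far-off-centre positions count as LEFT/RIGHT (assumed cut-off)
    position = 'LEFT' if x < -2.0 else 'RIGHT' if x > 2.0 else 'MIDDLE'
    # Cart velocity: sign of the horizontal speed
    velocity = 'LEFT' if x_dot < 0 else 'RIGHT'
    # Pole angle: which side the pole leans to, combined with how it is moving
    side = 'LEFT' if theta < 0 else 'RIGHT'
    if abs(theta_dot) < 0.1:                # pole barely moving (assumed cut-off)
        angle = f'STUCK_{side}'
    elif (theta < 0) == (theta_dot < 0):    # moving further away from upright
        angle = f'FALLING_{side}'
    else:                                   # moving back towards upright
        angle = f'STABILIZING_{side}'
    return (f'Position({position})', f'Velocity({velocity})', f'Angle({angle})')
```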

In [76]:
from pgeon import Agent
from ray.rllib.algorithms.algorithm import Algorithm

class CartpoleAgent(Agent):
    def __init__(self, path):
        self.agent = Algorithm.from_checkpoint(path)

    def act(self, state):
        return self.agent.compute_single_action(state)
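As `CartpoleAgent` shows, pgeon queries agents through `act(state)`. If no RLlib checkpoint is at hand, a hand-written policy could play the same role. The sketch below is a hypothetical heuristic agent, not part of pgeon; in the notebook it would subclass `pgeon.Agent` like `CartpoleAgent`, but the base class is omitted here so the snippet runs on its own. It relies on CartPole-v1's action convention: 0 pushes the cart left, 1 pushes it right.

```python
class HeuristicCartpoleAgent:
    """Hypothetical stand-in agent: push the cart toward the side the pole is falling."""

    def __init__(self, damping: float = 0.5):
        # Weight of the angular velocity relative to the angle (assumed value)
        self.damping = damping

    def act(self, state):
        _x, _x_dot, theta, theta_dot = state
        # Action 1 pushes right, action 0 pushes left
        return 1 if theta + self.damping * theta_dot > 0 else 0
```

Passing such an object to `pg.fit(...)` in place of `agent` should yield a PG of the heuristic's behavior rather than the PPO agent's.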

In [77]:
agent = CartpoleAgent('checkpoints/PPO_CartPole-v1_1acbb_00000_0_2023-12-05_19-28-36/checkpoint_000000')



## Policy Graph generation

In [78]:
from pgeon import PolicyGraph

Policy Graphs are instantiated with an environment and a discretizer.

In [79]:
pg = PolicyGraph(environment, discretizer)

We generate a Policy Graph with the `fit()` method, in this case from 200 episode trajectories of our agent. If the PG has already been fit, passing `update=True` extends it with the new trajectories instead of regenerating it from scratch.
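The `frequency` attributes printed below suggest that fitting boils down to counting how often each discretized state and each (state, next state, action) transition occurs in the recorded trajectories, with `update=True` accumulating into the existing counts. The sketch illustrates that idea with plain `Counter`s. `count_transitions` is a hypothetical helper, not pgeon's implementation, and it assumes trajectories in the flat (state0, action0, state1, ..., stateN) layout shown further down.

```python
from collections import Counter
from typing import Hashable, List, Optional, Sequence, Tuple

Counts = Tuple[Counter, Counter]

def count_transitions(trajectories: Sequence[List[Hashable]],
                      counts: Optional[Counts] = None) -> Counts:
    """Accumulate state and (s, s', a) frequencies from flat trajectories.

    counts=None starts from scratch (like update=False); passing the previous
    result keeps adding to it (like update=True). The Counters are updated in place.
    """
    nodes, edges = counts if counts is not None else (Counter(), Counter())
    for trajectory in trajectories:
        states, actions = trajectory[0::2], trajectory[1::2]
        nodes.update(states)
        # Edge keys follow the (from, to, action) order used by pg.edges
        edges.update(zip(states, states[1:], actions))
    return nodes, edges
```

The sketch counts every state of an episode, including the last one; whether pgeon does the same is not shown in this notebook.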

In [80]:
pg = pg.fit(agent, num_episodes=200, update=False)

Fitting PG...: 100%|██████████| 200/200 [00:16<00:00, 11.79it/s]


In [81]:
print(f'Number of nodes: {len(pg.nodes)}')
print(f'Number of edges: {len(pg.edges)}')

Number of nodes: 14
Number of edges: 131


Each node holds information about a discretized state:

In [82]:
arbitrary_state = list(pg.nodes)[0]

print(arbitrary_state)
print(f'  Times visited: {pg.nodes[arbitrary_state]["frequency"]}')
print(f'  p(s):          {pg.nodes[arbitrary_state]["probability"]:.3f}')

(Position(MIDDLE), Velocity(LEFT), Angle(STUCK_RIGHT))
  Times visited: 1501
  p(s):          0.036


Each edge has information about a transition between states:

In [83]:
arbitrary_edge = list(pg.edges)[0]

print(f'From:    {arbitrary_edge[0]}')
print(f'Action:  {arbitrary_edge[2]}')
print(f'To:      {arbitrary_edge[1]}')
print(f'  Times visited:      {pg[arbitrary_edge[0]][arbitrary_edge[1]][arbitrary_edge[2]]["frequency"]}')
print(f'  p(s_to,a | s_from): {pg[arbitrary_edge[0]][arbitrary_edge[1]][arbitrary_edge[2]]["probability"]:.3f}')

From:    (Position(MIDDLE), Velocity(LEFT), Angle(STUCK_RIGHT))
Action:  1
To:      (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_LEFT))
  Times visited:      39
  p(s_to,a | s_from): 0.026
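The printed probabilities are consistent with frequencies normalized as follows: node probabilities over all recorded state visits, and edge probabilities over all transitions leaving the source state. This is our reading of the numbers, not a definition taken from pgeon's documentation. For instance, the six edges leaving node 0 in the edge CSV exported below have frequencies 39, 739, 709, 10, 3 and 1, which sum to 1501, and 39 / 1501 matches the edge printed above:

$$
p(s) = \frac{\mathrm{freq}(s)}{\sum_{s''} \mathrm{freq}(s'')}, \qquad
p(s', a \mid s) = \frac{\mathrm{freq}(s, a, s')}{\sum_{a'', s''} \mathrm{freq}(s, a'', s'')}
$$

$$
p(s_{\mathrm{to}}, a{=}1 \mid s_{\mathrm{from}}) = \frac{39}{1501} \approx 0.026
$$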


The `PolicyGraph` object also stores the full discretized episode trajectories of the last fit.

In [84]:
len(pg._trajectories_of_last_fit)

200

Each trajectory is stored as a flat list that alternates discretized states and actions: (state0, action0, state1, ..., stateN).

In [85]:
pg._trajectories_of_last_fit[0]

[(Position(MIDDLE), Velocity(LEFT), Angle(STUCK_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(STUCK_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(FALLING_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(FALLING_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(FALLING_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(FALLING_LEFT)),
 0,
 (Position(MIDDLE), Velocity(

## Loading and saving Policy Graphs

### Pickle

Saving as pickle lets you restore the full state of the object.
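By default, `pickle` stores an object's instance attributes, private ones included, which is consistent with the restored PG below still carrying `_trajectories_of_last_fit`. The toy class here is a hypothetical stand-in for `PolicyGraph`, used only to show the round trip with the standard `pickle` module.

```python
import pickle

class ToyPolicyGraph:
    """Hypothetical stand-in for PolicyGraph: some graph data plus a private cache."""

    def __init__(self):
        self.edge_frequencies = {('s0', 's10', 1): 39}
        self._trajectories_of_last_fit = [['s0', 1, 's10']]

original = ToyPolicyGraph()
# Round trip through bytes; pg.save('pickle', path) presumably writes to a file instead
restored = pickle.loads(pickle.dumps(original))
print(restored._trajectories_of_last_fit)
```

Only unpickle files you trust: loading a pickle can execute arbitrary code.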

In [86]:
pg.save('pickle', './ppo-cartpole.pickle')

In [87]:
pg_pickle = PolicyGraph.from_pickle('./ppo-cartpole.pickle')

print(f'Number of nodes:             {len(pg_pickle.nodes)}')
print(f'Number of edges:             {len(pg_pickle.edges)}')
print(f'Num. of stored trajectories: {len(pg_pickle._trajectories_of_last_fit)}')

Number of nodes:             14
Number of edges:             131
Num. of stored trajectories: 200


### CSV

Saving as CSV creates three separate CSV files, holding node, edge and trajectory information respectively.

In [88]:
import csv

In [89]:
pg.save('csv', ['./ppo-cartpole_nodes.csv', './ppo-cartpole_edges.csv', './ppo-cartpole_trajectories.csv'])

In [90]:
with open('ppo-cartpole_nodes.csv', 'r', newline='') as f:
    csv_r = csv.reader(f)
    for _ in range(10):  # header row plus the first nine nodes
        print(next(csv_r))

['id', 'value', 'p(s)', 'frequency']
['0', 'Position(MIDDLE)&Velocity(LEFT)&Angle(STUCK_RIGHT)', '0.0363570303984498', '1501']
['1', 'Position(MIDDLE)&Velocity(RIGHT)&Angle(FALLING_RIGHT)', '0.1818578176093012', '7508']
['2', 'Position(MIDDLE)&Velocity(LEFT)&Angle(STABILIZING_RIGHT)', '0.07910863509749304', '3266']
['3', 'Position(MIDDLE)&Velocity(RIGHT)&Angle(FALLING_LEFT)', '0.06956521739130435', '2872']
['4', 'Position(MIDDLE)&Velocity(RIGHT)&Angle(STUCK_LEFT)', '0.0572362843647814', '2363']
['5', 'Position(MIDDLE)&Velocity(LEFT)&Angle(FALLING_RIGHT)', '0.12387065520164708', '5114']
['6', 'Position(MIDDLE)&Velocity(RIGHT)&Angle(STABILIZING_LEFT)', '0.21928061039118324', '9053']
['7', 'Position(MIDDLE)&Velocity(RIGHT)&Angle(STUCK_RIGHT)', '0.16691292236889912', '6891']
['8', 'Position(MIDDLE)&Velocity(LEFT)&Angle(FALLING_LEFT)', '0.004505268257236284', '186']


Edges and trajectories refer to states by their node IDs, as listed in the corresponding node CSV file.

In [91]:
with open('ppo-cartpole_edges.csv', 'r', newline='') as f:
    csv_r = csv.reader(f)
    for _ in range(10):  # header row plus the first nine edges
        print(next(csv_r))

['from', 'to', 'action', 'p(s)', 'frequency']
['0', '10', '1', '0.02598267821452365', '39']
['0', '5', '0', '0.49233844103930713', '739']
['0', '6', '1', '0.47235176548967356', '709']
['0', '13', '0', '0.006662225183211193', '10']
['0', '12', '1', '0.001998667554963358', '3']
['0', '2', '0', '0.0006662225183211193', '1']
['1', '6', '1', '0.1709915323682054', '1252']
['1', '5', '0', '0.07702813438951106', '564']
['1', '7', '1', '0.2783392515706091', '2038']
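To turn those IDs back into states, one can index the node file by `id`. The sketch below does this with the standard `csv` module on a few rows copied from the node and edge outputs in this section; `load_nodes` and `decode_edges` are hypothetical helpers, and splitting `value` on `&` simply follows the format visible in the node file.

```python
import csv
import io

# A few rows copied from the node and edge CSV outputs in this notebook
NODES_CSV = """id,value,p(s),frequency
0,Position(MIDDLE)&Velocity(LEFT)&Angle(STUCK_RIGHT),0.0363570303984498,1501
5,Position(MIDDLE)&Velocity(LEFT)&Angle(FALLING_RIGHT),0.12387065520164708,5114
"""
EDGES_CSV = """from,to,action,p(s),frequency
0,5,0,0.49233844103930713,739
"""

def load_nodes(f):
    """Map each node ID (as a string) to its tuple of predicates."""
    return {row['id']: tuple(row['value'].split('&')) for row in csv.DictReader(f)}

def decode_edges(f, nodes):
    """Yield (from_state, action, to_state, frequency) with IDs replaced by states."""
    for row in csv.DictReader(f):
        yield nodes[row['from']], int(row['action']), nodes[row['to']], int(row['frequency'])

nodes = load_nodes(io.StringIO(NODES_CSV))
for edge in decode_edges(io.StringIO(EDGES_CSV), nodes):
    print(edge)
```

With the real files, pass `open('ppo-cartpole_nodes.csv', newline='')` and `open('ppo-cartpole_edges.csv', newline='')` instead of the `StringIO` objects.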


Each row of the trajectory file stores one episode as a flat (state0, action0, state1, ..., stateN) list, with states given as node IDs.

In [92]:
with open('ppo-cartpole_trajectories.csv', 'r', newline='') as f:
    csv_r = csv.reader(f)
    print(next(csv_r))  # first stored episode

['9', '0', '2', '1', '9', '0', '2', '1', '8', '0', '2', '0', '2', '1', '2', '1', '8', '0', '2', '0', '2', '1', '2', '1', '8', '0', '2', '1', '8', '0', '9', '1', '8', '0', '9', '0', '2', '1', '9', '0', '2', '1', '9', '1', '8', '0', '9', '0', '2', '1', '9', '0', '2', '1', '0', '1', '10', '0', '0', '1', '10', '0', '9', '0', '2', '1', '0', '0', '5', '1', '0', '0', '5', '1', '0', '0', '5', '1', '0', '1', '10', '0', '0', '1', '10', '0', '5', '1', '10', '0', '5', '1', '10', '0', '5', '0', '5', '1', '5', '1', '10', '0', '5', '1', '10', '0', '5', '1', '0', '0', '5', '1', '0', '0', '5', '1', '0', '0', '5', '1', '0', '0', '5', '1', '0', '1', '6', '0', '0', '0', '5', '1', '0', '1', '6', '0', '5', '1', '6', '1', '6', '0', '6', '0', '5', '1', '7', '1', '6', '0', '7', '0', '5', '1', '7', '1', '6', '1', '6', '0', '6', '0', '7', '1', '6', '0', '7', '0', '5', '1', '7', '1', '6', '0', '7', '1', '6', '0', '7', '1', '6', '0', '7', '1', '6', '0', '1', '1', '6', '0', '1', '0', '5', '1', '1', '1', '6', '1', '
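Because states and actions alternate, the even positions of such a row are node IDs and the odd positions are actions. A hypothetical decoder for one row, using a few node values copied from the node file above and a short fragment of the row:

```python
from typing import Dict, List, Sequence, Tuple, Union

# Node ID -> value, copied from the node CSV output above (subset)
NODE_VALUES: Dict[str, str] = {
    '0': 'Position(MIDDLE)&Velocity(LEFT)&Angle(STUCK_RIGHT)',
    '5': 'Position(MIDDLE)&Velocity(LEFT)&Angle(FALLING_RIGHT)',
    '6': 'Position(MIDDLE)&Velocity(RIGHT)&Angle(STABILIZING_LEFT)',
}

def decode_trajectory(row: Sequence[str], node_values: Dict[str, str]
                      ) -> List[Union[Tuple[str, ...], int]]:
    """Turn a row of IDs [s0, a0, s1, ...] into [state0, action0, state1, ...]."""
    decoded: List[Union[Tuple[str, ...], int]] = []
    for i, item in enumerate(row):
        if i % 2 == 0:   # even index: node ID
            decoded.append(tuple(node_values[item].split('&')))
        else:            # odd index: action
            decoded.append(int(item))
    return decoded

print(decode_trajectory(['0', '1', '6', '0', '5'], NODE_VALUES))
```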

There are two ways of loading Policy Graphs from CSV files: from the node and trajectory files, or from the node and edge files. Only the former restores episode trajectories; loading from nodes and edges leaves the stored trajectories empty.

In [93]:
pg_csv = PolicyGraph.from_nodes_and_trajectories('./ppo-cartpole_nodes.csv', './ppo-cartpole_trajectories.csv',
                                                 environment, discretizer)
print(f'Number of nodes:             {len(pg_csv.nodes)}')
print(f'Number of edges:             {len(pg_csv.edges)}')
print(f'Num. of stored trajectories: {len(pg_csv._trajectories_of_last_fit)}')

Number of nodes:             14
Number of edges:             131
Num. of stored trajectories: 200


In [94]:
pg_csv = PolicyGraph.from_nodes_and_edges('./ppo-cartpole_nodes.csv', './ppo-cartpole_edges.csv',
                                          environment, discretizer)
print(f'Number of nodes:             {len(pg_csv.nodes)}')
print(f'Number of edges:             {len(pg_csv.edges)}')
print(f'Num. of stored trajectories: {len(pg_csv._trajectories_of_last_fit)}')

Number of nodes:             14
Number of edges:             131
Num. of stored trajectories: 0


### Gram

PGs can also be exported to the [gram](https://neo4j.com/developer-blog/gram-a-data-graph-format/) format, allowing visualization using Neo4j. Episode trajectories cannot be stored in this format, though.

PGs currently cannot be loaded from a Gram file.

In [95]:
pg.save('gram', './ppo-cartpole.gram')

In [96]:
!head ./ppo-cartpole.gram


CREATE (s0:State {
  uid: "s0",
  value: "Position(MIDDLE)&Velocity(LEFT)&Angle(STUCK_RIGHT)",
  probability: 0.0363570303984498, 
  frequency:1501
});
CREATE (s1:State {
  uid: "s1",
  value: "Position(MIDDLE)&Velocity(RIGHT)&Angle(FALLING_RIGHT)",


In [97]:
!tail ./ppo-cartpole.gram

MATCH (s13:State) WHERE s13.uid = "s13" MATCH (s0:State) WHERE s0.uid = "s0" CREATE (s13)-[:a1 {probability:0.1506849315068493, frequency:22}]->(s0);
MATCH (s13:State) WHERE s13.uid = "s13" MATCH (s10:State) WHERE s10.uid = "s10" CREATE (s13)-[:a1 {probability:0.03424657534246575, frequency:5}]->(s10);
MATCH (s13:State) WHERE s13.uid = "s13" MATCH (s7:State) WHERE s7.uid = "s7" CREATE (s13)-[:a1 {probability:0.18493150684931506, frequency:27}]->(s7);
MATCH (s13:State) WHERE s13.uid = "s13" MATCH (s5:State) WHERE s5.uid = "s5" CREATE (s13)-[:a0 {probability:0.1506849315068493, frequency:22}]->(s5);
MATCH (s13:State) WHERE s13.uid = "s13" MATCH (s5:State) WHERE s5.uid = "s5" CREATE (s13)-[:a1 {probability:0.1232876712328767, frequency:18}]->(s5);
MATCH (s13:State) WHERE s13.uid = "s13" MATCH (s12:State) WHERE s12.uid = "s12" CREATE (s13)-[:a1 {probability:0.0410958904109589, frequency:6}]->(s12);
MATCH (s13:State) WHERE s13.uid = "s13" MATCH (s2:State) WHERE s2.uid = "s2" CREATE (s
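Each edge becomes a `MATCH ... MATCH ... CREATE` statement whose relationship type encodes the action (`a0`, `a1`). The hypothetical helper below rebuilds one such line from its parts, following the format visible in the `!tail` output above; it only documents the format and is not a replacement for `pg.save('gram', ...)`.

```python
def gram_edge(u: int, v: int, action: int, probability: float, frequency: int) -> str:
    """Build a gram edge statement in the format shown by `!tail` above (hypothetical helper)."""
    src, dst = f's{u}', f's{v}'
    return (f'MATCH ({src}:State) WHERE {src}.uid = "{src}" '
            f'MATCH ({dst}:State) WHERE {dst}.uid = "{dst}" '
            # Doubled braces produce the literal { } of the property map
            f'CREATE ({src})-[:a{action} {{probability:{probability}, frequency:{frequency}}}]->({dst});')

print(gram_edge(13, 0, 1, 0.1506849315068493, 22))
```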

## Using PG-based policies

With `PGBasedPolicy`, we can create policies that replicate an agent's behavior based on its Policy Graph. `PGBasedPolicy` subclasses `pgeon.Agent`.

The policy mode (greedy/stochastic) can be specified via the `PGBasedPolicyMode` enum. The behavior when encountering an unknown node (select random action/search nearest node in PG) can be specified via the `PGBasedPolicyNodeNotFoundMode` enum.
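pgeon's internals are not shown in this notebook, so the following is only a sketch of how the greedy and stochastic modes and the random-uniform fallback could work, assuming the policy sums the frequencies of the edges leaving the current node per action to get p(a | s). The edge data are node 0's outgoing edges from the edge CSV above; `action_distribution` and `choose_action` are hypothetical names, and the nearest-node fallback is not sketched.

```python
import random
from collections import defaultdict

# Outgoing edges of node 0 as (to, action, frequency), from the edge CSV above
PG = {'0': [('10', 1, 39), ('5', 0, 739), ('6', 1, 709),
            ('13', 0, 10), ('12', 1, 3), ('2', 0, 1)]}

def action_distribution(out_edges):
    """p(a | s): sum outgoing edge frequencies per action, then normalize."""
    totals = defaultdict(int)
    for _to, action, freq in out_edges:
        totals[action] += freq
    norm = sum(totals.values())
    return {a: f / norm for a, f in totals.items()}

def choose_action(pg, state, greedy=True, actions=(0, 1), rng=None):
    rng = rng if rng is not None else random.Random()
    if state not in pg:
        # Unknown state: fall back to a uniformly random action
        return rng.choice(list(actions))
    dist = action_distribution(pg[state])
    if greedy:
        return max(dist, key=dist.get)  # most frequent action in this state
    acts = list(dist)
    return rng.choices(acts, weights=[dist[a] for a in acts])[0]  # sample by p(a | s)

print(choose_action(PG, '0'))  # → 1
```

In this state the two actions are almost equally frequent (751 vs 750 transitions), so a greedy PG policy always picks action 1 where the agent pushed left about half the time; sampling stochastically keeps that split.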

In [98]:
from pgeon import PGBasedPolicy, PGBasedPolicyMode, PGBasedPolicyNodeNotFoundMode

In [99]:
policy = PGBasedPolicy(pg, mode=PGBasedPolicyMode.GREEDY,
                       node_not_found_mode=PGBasedPolicyNodeNotFoundMode.RANDOM_UNIFORM)

In [100]:
environment = gym.make('CartPole-v1', render_mode='human')
# Render episodes until the cell is interrupted
while True:
    obs, _ = environment.reset()
    done = False
    steps = 0
    while not done and steps < 200:
        action = policy.act(obs)
        # Gymnasium's step() returns (obs, reward, terminated, truncated, info)
        obs, _, terminated, truncated, _ = environment.step(action)
        done = terminated or truncated
        steps += 1
    # print(f'Observed state:  {obs}')
    # print(f'Discretization:  {policy.pg.discretizer.discretize(obs)}')
    # print(f'Selected action: {action}')



KeyboardInterrupt: 

In [6]:
from IPython.display import HTML, Javascript, display
HTML("""
<div id="viz"></div>
<script src="neovis.js"></script>
<script type="text/Javascript">
    neoViz = null;
    // uid of the State node to highlight in red
    marked = "s7";

    config = {
            containerId: "viz",
            neo4j: {
                serverUrl: "bolt://localhost:7687",
                serverUser: "neo4j",
                serverPassword: "<your-neo4j-password>",  // avoid committing real credentials
            },
            labels: {
                State: {
                    label: "uid",
                    value: "probability",
                    group: "community",
                    [NeoVis.NEOVIS_ADVANCED_CONFIG]: {
                        function: {
                            color: function (node) {
                                if (node.properties.uid == marked) {
                                    return "#ff0000";
                                } else {
                                    return null;
                                }
                            }
                        }
                    }
                }
            },
            relationships: {
                a0: {
                    label: "name",
                    caption: true,
                    value: "probability",
                    [NeoVis.NEOVIS_ADVANCED_CONFIG]: {
                        static: {
                            label: "L",     //content on edge
                            color: "#ff5378",
                            font: {
                                "background": "none",
                                "strokeWidth": "0",
                                "size": 20,
                                "color": "#000000"  //font color on edge
                            }
                        }
                    }
                },
                a1: {
                    label: "name",
                    value: "probability",
                    [NeoVis.NEOVIS_ADVANCED_CONFIG]: {
                        static: {
                            label: "R",     //content on edge
                            color: "#461274",
                            font: {
                                "background": "none",
                                "strokeWidth": "0",
                                "size": 20,
                                "color": "#000000"  //font color on edge
                            }
                        }
                    }
                }
            },
        };

    function draw() {
        neoViz = new NeoVis.default(config);
        neoViz.render();
    }
    draw();
</script>
""")