In [1]:
import gym
import numpy as np
import random
from collections import deque
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch_geometric.nn.conv import GCNConv
import math

Setup:
- Starting point: just try to train classifier on RL policies

In [2]:
### DQN implementation

# Define the Q-network
class DQN(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super(DQN, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, x):
        return self.fc(x)
    
    def get_weights(self):
        return self.state_dict()

# Experience Replay Buffer
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        state = np.expand_dims(state, 0)
        next_state = np.expand_dims(next_state, 0)
            
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        return np.concatenate(state), action, reward, np.concatenate(next_state), done

    def __len__(self):
        return len(self.buffer)

# DQN Agent
class DQNAgent:
    def __init__(self, state_dim, action_dim, hidden_dim=64, lr=1e-2, batch_size=64, gamma=0.99, replay_size=1000):
        self.model = DQN(state_dim, action_dim, hidden_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.replay_buffer = ReplayBuffer(replay_size)
        self.batch_size = batch_size
        self.gamma = gamma
        self.action_dim = action_dim
    
    def update(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        state, action, reward, next_state, done = self.replay_buffer.sample(self.batch_size)
        
        state = torch.FloatTensor(state)
        next_state = torch.FloatTensor(next_state)
        if len(state.shape) == 1:
            state = state.reshape(-1, 1)
        if len(next_state.shape) == 1:
            next_state = next_state.reshape(-1, 1)
        action = torch.LongTensor(action)
        reward = torch.FloatTensor(reward)
        done = torch.FloatTensor(done)

        q_values = self.model.forward(state)
        next_q_values = self.model.forward(next_state)

        # state = state.T
        # next_state = next_state.T
        
        q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
        next_q_value = next_q_values.max(1)[0]
        expected_q_value = reward + self.gamma * next_q_value * (1 - done)
        
        loss = nn.MSELoss()(q_value, expected_q_value.detach())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def act(self, state, epsilon):
        if random.random() > epsilon:
            state = torch.FloatTensor(np.expand_dims(state, 0))
            q_value = self.model(state)
            action = q_value.max(-1)[1].item()
        else:
            action = random.randrange(self.action_dim)
        return action
    
class QTableAgent:
    def __init__(self, state_dim, action_dim, lr=1e-2, gamma=0.99):
        self.q_table = np.zeros((state_dim, action_dim))
        self.lr = lr
        self.gamma = gamma
        self.action_dim = action_dim
    
    def update(self, state, action, reward, next_state, done):
        q_value = self.q_table[state, action]
        next_q_value = np.max(self.q_table[next_state])
        expected_q_value = reward + self.gamma * next_q_value * (1 - done)
        
        self.q_table[state, action] += self.lr * (expected_q_value - q_value)
    
    def act(self, state, epsilon):
        if random.random() > epsilon:
            action = np.argmax(self.q_table[state])
        else:
            action = random.randrange(self.action_dim)
        return action

In [3]:
NUM_NON_ZERO_REWARDS = 0
def one_hot_state(state, env):
    state_arr = np.zeros(env.observation_space.n)
    state_arr[state] = 1
    return state_arr

def train_dqn(env_name="CartPole-v1", episodes=500, epsilon_start=1.0, epsilon_final=0.01, 
              epsilon_decay=500, reward_function = None, verbose = False, return_reward = False, 
              print_every=50, **kwargs):
    """
    Train a DQN agent on the specified environment.
    
    Args:
        env_name: str
            Name of the environment to train the agent on.
        episodes: int
            Number of episodes to train the agent for.
        epsilon_start: float
            Initial epsilon value for epsilon-greedy action selection.
        epsilon_final: float
            Final epsilon value for epsilon-greedy action selection.
        epsilon_decay: float
            Decay rate for epsilon.
        reward_function: function
            Optional reward function to use for training.
        verbose: bool
            Whether to print training progress.

    Returns:
        DQNAgent: trained DQN agent. 
    """
    global NUM_NON_ZERO_REWARDS
    env = gym.make(env_name)
    if len(env.observation_space.shape) == 0:
        state_dim = env.observation_space.n
    else:
        state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    agent = DQNAgent(state_dim, action_dim, **kwargs)
    
    epsilon_by_frame = lambda frame_idx: epsilon_final + (epsilon_start - epsilon_final) * np.exp(-1. * frame_idx / epsilon_decay)
    
    rewards = np.zeros(episodes) 
    is_state_discrete = hasattr(env.observation_space, 'n')
    for episode in range(episodes):
        state = env.reset() # Reset the environment, reward
        if is_state_discrete and state_dim == env.observation_space.n:
            state = one_hot_state(state, env)
        episode_reward = 0
        while True:
            epsilon = epsilon_by_frame(episode)
            # One-hot encode the state
            action = agent.act(state, epsilon)
            next_state, reward, done, _ = env.step(action)
            if is_state_discrete and state_dim == env.observation_space.n:
                next_state = one_hot_state(next_state, env)

            if reward_function: #custom reward function
                reward = reward_function(done, state, action, next_state)
            NUM_NON_ZERO_REWARDS += 0 if math.isclose(reward, 0) else 1
            
            agent.replay_buffer.push(state, action, reward, next_state, done)
            state = next_state
            episode_reward += reward
            
            agent.update()
            
            if done:
                break
            # print(f"Episode: {episode+1}, Total reward: {episode_reward}, Epsilon: {epsilon:.2f}")

        rewards[episode] = episode_reward
        # Optional: Render the environment to visualize training progress
        if verbose and episode % print_every == print_every - 1:
        #     render_env(env, agent)
            print(f"Episode: {episode+1}, Average total reward: {np.average(rewards[episode - print_every + 1 : episode])}, Epsilon: {epsilon:.2f}")

    env.close()
    return agent if not return_reward else (agent, rewards)

# Optional: Function to render the environment with the current policy
def render_env(env, agent):
    state = env.reset()
    done = False
    while not done:
        action = agent.act(state, 0)  # Using 0 epsilon for greedy action selection
        # print(env.step(action))
        next_state, reward, done, _ = env.step(action)
        env.render()
        state = next_state

In [4]:
def train_qtable(env_name="CartPole-v1", episodes=500, epsilon_start=1.0, epsilon_final=0.01, 
              epsilon_decay=500, reward_function = None, verbose = False, return_reward = False, 
              print_every=50, **kwargs):
    """
    Train a Q-table agent on the specified environment."""
    global NUM_NON_ZERO_REWARDS
    env = gym.make(env_name)
    if len(env.observation_space.shape) == 0:
        state_dim = env.observation_space.n
    else:
        state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    agent = QTableAgent(state_dim, action_dim, **kwargs)

    rewards = np.zeros(episodes)
    epsilon_by_frame = lambda frame_idx: epsilon_final + (epsilon_start - epsilon_final) * np.exp(-1. * frame_idx / epsilon_decay)
    for episode in range(episodes):
        state = env.reset()
        episode_reward = 0
        while True:
            epsilon = epsilon_by_frame(episode)
            action = agent.act(state, epsilon)
            next_state, reward, done, _ = env.step(action)
            if reward_function:
                reward = reward_function(done, state, action, next_state)

            agent.update(state, action, reward, next_state, done)
            state = next_state
            episode_reward += reward
            if done:
                break

        rewards[episode] = episode_reward
        if verbose and episode % print_every == print_every - 1:
            print(f"Episode: {episode+1}, Average total reward: {np.average(rewards[episode - print_every + 1 : episode])}, Epsilon: {epsilon:.2f}")
        
    env.close()
    return agent if not return_reward else (agent, rewards)
    

In [5]:
NEAR_ZERO = 1e-9
def test_dqn(env, agent, episodes=10, reward_function=None, verbose = False):
    print(f"Maximum reward: {env.spec.reward_threshold}")
    average_value = 0
    for episode in range(episodes):
        # if episode == 0:
        #     render_env(env, agent)
        state = env.reset()
        if len(env.observation_space.shape) == 0:
            state = one_hot_state(state, env)
        episode_reward = 0
        done = False
        while not done:
            action = agent.act(state, 0)  # Using 0 epsilon for greedy action selection
            next_state, reward, done, _ = env.step(action)
            if len(env.observation_space.shape) == 0:
                next_state = one_hot_state(next_state, env)
            if reward_function:
                reward = reward_function(done, state, action, next_state)
            episode_reward += reward
            state = next_state
        if verbose:
            print(f"Episode: {episode+1}, Total reward: {episode_reward}")
        average_value += episode_reward
    average_value /= episodes
    print(f"Average reward: {average_value}")
    

In [6]:
def test_qtable(env, agent, episodes=10, reward_function=None, verbose = False):
    """
    Test a Q-table agent on the specified environment.
    (This is basically test_dqn but without the one-hot encoding.)
    """
    print(f"Maximum reward: {env.spec.reward_threshold}")
    average_value = 0
    for episode in range(episodes):
        state = env.reset()
        episode_reward = 0
        done = False
        while not done:
            action = agent.act(state, 0)  # Using 0 epsilon for greedy action selection
            next_state, reward, done, _ = env.step(action)
            if reward_function:
                reward = reward_function(done, state, action, next_state)
            episode_reward += reward
            state = next_state
        if verbose:
            print(f"Episode: {episode+1}, Total reward: {episode_reward}")
        average_value += episode_reward
    average_value /= episodes
    print(f"Average reward: {average_value}")

In [7]:
env_name = "Taxi-v3"
agent, rewards = train_qtable(env_name = env_name, episodes = 10000, verbose = True, return_reward = True,
                           epsilon_decay=10, lr=0.1, gamma=0.9)
# rewards

Episode: 50, Average total reward: -398.55102040816325, Epsilon: 0.02
Episode: 100, Average total reward: -222.81632653061226, Epsilon: 0.01


Episode: 150, Average total reward: -191.3469387755102, Epsilon: 0.01
Episode: 200, Average total reward: -179.83673469387756, Epsilon: 0.01
Episode: 250, Average total reward: -163.0, Epsilon: 0.01
Episode: 300, Average total reward: -137.6734693877551, Epsilon: 0.01
Episode: 350, Average total reward: -128.9795918367347, Epsilon: 0.01
Episode: 400, Average total reward: -112.12244897959184, Epsilon: 0.01
Episode: 450, Average total reward: -79.36734693877551, Epsilon: 0.01
Episode: 500, Average total reward: -80.20408163265306, Epsilon: 0.01
Episode: 550, Average total reward: -60.6530612244898, Epsilon: 0.01
Episode: 600, Average total reward: -42.734693877551024, Epsilon: 0.01
Episode: 650, Average total reward: -44.69387755102041, Epsilon: 0.01
Episode: 700, Average total reward: -35.08163265306123, Epsilon: 0.01
Episode: 750, Average total reward: -29.285714285714285, Epsilon: 0.01
Episode: 800, Average total reward: -22.081632653061224, Epsilon: 0.01
Episode: 850, Average total 

In [8]:
test_qtable(gym.make(env_name), agent, episodes = 100)
# agent.model.get_weights()

Maximum reward: 8
Average reward: 8.12


In [9]:
### Coherence classifier

#agent.model.get_weights()

# Define a simple GCN model
from torch_geometric.data import Data
class GCN(torch.nn.Module):
    def __init__(self, data):
        super(GCN, self).__init__()
        # Define the GCN layers
        self.conv1 = GCNConv(data.num_node_features, 4)  # Input features to hidden
        self.conv2 = GCNConv(4, 2)  # Hidden to output features
        self.data = data

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        
        # Pass data through the first GCN layer, then apply ReLU
        x = torch.relu(self.conv1(x, edge_index))
        # Pass data through the second GCN layer
        x = self.conv2(x, edge_index)
        return x
    
    
def nn_to_data(model: nn.Module) -> Data:
    edges = []

    # Counter for global neuron index
    idx = 0

    # Iterate over each layer in the network
    base = next(model.children())
    if isinstance(base, nn.Sequential):
        layers = list(base.children())
        layers2 = list(base.children())
    else:
        layers = list(model.children()) # iterator over the layers of the model
        layers2 = list(model.children())
    
    num_nodes = layers2[0].weight.shape[1] + sum([layer.weight.shape[0] for layer in layers2 if isinstance(layer, nn.Linear)])
    num_node_features = num_nodes
    node_features = torch.zeros(num_nodes, num_node_features)
    # shape = (num_nodes, num_node_features), where the node features are the bias of each node
    # and the weights of the edges to each node (zero if there is no edge)

    for layer in layers:
        if isinstance(layer, nn.Linear):
            # Update edges based on the weight matrix
            input_dim = layer.weight.shape[1]
            output_dim = layer.weight.shape[0]
            for i in range(input_dim):  # Input neurons (e.g. 4)
                for j in range(output_dim):  # Output neurons (e.g. 64)
                    edges.append((idx + i, idx + input_dim + j))
            
            # Update node features (e.g., biases)
            biases = torch.tensor(layer.bias.detach().numpy())
            edge_weights = torch.tensor(layer.weight.detach().numpy().T)
            node_features[idx + input_dim:idx + input_dim + output_dim, 0] = biases
            node_features[idx:idx + input_dim, 1+idx:1+idx+output_dim] = edge_weights
            node_features[idx + input_dim:idx + input_dim + output_dim, 1+idx:1+idx+input_dim] = edge_weights.T
            
            # Update the global neuron index
            idx += input_dim

    # Convert lists to PyTorch tensors
    num_nonzero = [np.count_nonzero(node_features[i]) for i in range(node_features.shape[0])]
    # print(num_nonzero)
    row_mean, row_median, row_var = torch.mean(node_features[:, 1:], dim=1), torch.median(node_features[:, 1:], dim=1)[0], torch.var(node_features[:, 1:], dim=1)
    x = torch.stack([node_features[:, 0], row_mean, row_median, row_var]).T
    # print(x.shape)
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    return Data(x=x, edge_index=edge_index)

agent = train_dqn(env_name = "CartPole-v1", episodes = 1000, verbose = True, return_reward = False)
data = nn_to_data(agent.model)
gcn = GCN(data)
# data.x.shape, data.edge_index.shape
# print(data.x)

#Debug
out_of_bounds = data.edge_index >= data.x.shape[0]
if out_of_bounds.any():
    print("Out-of-bounds indices found at locations:")
    print(data.edge_index[:, out_of_bounds.any(dim=0)])

Episode: 50, Average total reward: 18.102040816326532, Epsilon: 0.91
Episode: 100, Average total reward: 22.224489795918366, Epsilon: 0.82
Episode: 150, Average total reward: 20.612244897959183, Epsilon: 0.74
Episode: 200, Average total reward: 17.408163265306122, Epsilon: 0.67
Episode: 250, Average total reward: 16.53061224489796, Epsilon: 0.61
Episode: 300, Average total reward: 16.163265306122447, Epsilon: 0.55
Episode: 350, Average total reward: 14.755102040816327, Epsilon: 0.50
Episode: 400, Average total reward: 12.755102040816327, Epsilon: 0.46
Episode: 450, Average total reward: 13.0, Epsilon: 0.41
Episode: 500, Average total reward: 12.0, Epsilon: 0.37
Episode: 550, Average total reward: 13.63265306122449, Epsilon: 0.34
Episode: 600, Average total reward: 14.142857142857142, Epsilon: 0.31
Episode: 650, Average total reward: 15.428571428571429, Epsilon: 0.28
Episode: 700, Average total reward: 18.26530612244898, Epsilon: 0.25
Episode: 750, Average total reward: 45.6734693877551

In [10]:
# Dataset generation
env_name = "CartPole-v1"
env = gym.make(env_name)
NEAR_ZERO = 1e-9
NUM_REWARD_CALLS = 0
NUM_NON_ZERO_REWARDS = 0
def deterministic_random(*args, lb = -1, ub = 1, sparsity = 0.0, continuous = False):
    """
    Create a deterministic random number generator for a given set of arguments.
    Used to generate deterministic reward functions for the coherence classifier.
    [Edit 4/3/24: adapted to continuous state space]"""
    global NUM_REWARD_CALLS
    NUM_REWARD_CALLS += 1
    unique_seed = f"{args}".encode("utf-8")
    random.seed(unique_seed)
    return random.uniform(lb, ub) if random.random() > sparsity else random.uniform(-NEAR_ZERO, NEAR_ZERO)

NUM_TRAIN_R_FUNCS = 50
NUM_EPS_TRAIN_R = 50
URS_r_funcs = [lambda *args: deterministic_random(args) for _ in range(NUM_TRAIN_R_FUNCS)]
URS_agents = [train_dqn(env_name = env_name, 
                        episodes=NUM_EPS_TRAIN_R, reward_function=r_func) for r_func in URS_r_funcs]
USS_r_funcs = [lambda *args: deterministic_random(args, sparsity=0.99) for _ in range(NUM_TRAIN_R_FUNCS)]
print(f"Number of reward function calls: {NUM_REWARD_CALLS}")
print(f"Number of non-zero rewards: {NUM_NON_ZERO_REWARDS}")
USS_agents = [train_dqn(env_name = env_name, 
                        episodes=NUM_EPS_TRAIN_R, reward_function=r_func) for r_func in USS_r_funcs]

Number of reward function calls: 58310
Number of non-zero rewards: 58310


In [18]:
# Test if deterministic_random is deterministic and has the correct sparsity
assert deterministic_random(1, 2, 3, 4) == deterministic_random(1, 2, 3, 4)
assert not deterministic_random(1, 2, 3, 4) == deterministic_random(1, 2, 3, 6)
[deterministic_random(1, 2, 3, i, sparsity = 0.5) for i in range(10)]

[0.6014137224608205,
 5.734868810378968e-10,
 0.18947200717913826,
 -0.11464719428521586,
 1.9375194864306798e-11,
 3.1131227593489704e-10,
 -7.023178277046693e-10,
 2.965355797951794e-10,
 0.41831271768541045,
 -3.207247699354683e-10]

In [11]:
# Test when do USS agents have non-zero rewards
env_name = "CartPole-v1"
USS_test_r_func = lambda *args: deterministic_random(args, sparsity=0.0)
assert USS_test_r_func(42) == USS_test_r_func(42)
USS_test_agent = train_dqn(env_name = env_name, episodes=500, reward_function=USS_test_r_func, 
                           verbose = True)
# Epsilon measuring how much the agent is exploring

Episode: 50, Average total reward: -0.7109034800800098, Epsilon: 0.91
Episode: 100, Average total reward: -0.17504186308399963, Epsilon: 0.82
Episode: 150, Average total reward: -0.82087640624123, Epsilon: 0.74
Episode: 200, Average total reward: -0.3467656369794918, Epsilon: 0.67
Episode: 250, Average total reward: -0.3202516562303173, Epsilon: 0.61
Episode: 300, Average total reward: -1.2871440511690573, Epsilon: 0.55
Episode: 350, Average total reward: 0.14199784219845837, Epsilon: 0.50
Episode: 400, Average total reward: 0.37669043104779626, Epsilon: 0.46
Episode: 450, Average total reward: -0.5010131068566521, Epsilon: 0.41
Episode: 500, Average total reward: -0.5665183593912818, Epsilon: 0.37


In [12]:
# epsilon_final, epsilon_start, epsilon_decay = 0.01, 1.0, 500
# [epsilon_final + (epsilon_start - epsilon_final) * np.exp(-1. * frame_idx / epsilon_decay) for frame_idx in range(500)]
test_dqn(gym.make(env_name), USS_test_agent, reward_function=USS_test_r_func)

Maximum reward: 475.0


Average reward: -0.5560146586302399


In [13]:
def get_state_shape(env):
    return 1 if len(env.observation_space.shape) == 0 else env.observation_space.shape[0]
def get_state_size(env):
    return env.observation_space.n if len(env.observation_space.shape) == 0 else env.observation_space.shape[0]
UPS_agents = [DQNAgent(get_state_size(env), env.action_space.n) for _ in range(NUM_TRAIN_R_FUNCS)]

import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool, GCNConv, GATConv

class GraphLevelGCN(torch.nn.Module):
    def __init__(self, num_node_features):
        super(GraphLevelGCN, self).__init__()
        self.conv1 = GCNConv(num_node_features, 16)
        self.conv2 = GCNConv(16, 16)
        self.linear = torch.nn.Linear(16, 1)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # edge_weights = data.edge_attr
        
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        
        # Aggregate node features to graph-level features
        x = global_mean_pool(x, batch)
        
        # Make a binary classification prediction
        x = self.linear(x)
        return torch.sigmoid(x)

class GATGraphLevelBinary(torch.nn.Module):
    def __init__(self, num_node_features):
        super(GATGraphLevelBinary, self).__init__()
        self.conv1 = GATConv(num_node_features, 8, heads=8, dropout=0.6)
        # Increase the number of output features from the first GAT layer
        self.conv2 = GATConv(8 * 8, 16, heads=1, concat=False, dropout=0.6)
        # Additional GAT layer for richer node representations
        self.linear = torch.nn.Linear(16, 1)
        # Final linear layer to produce a graph-level output

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        
        x = global_mean_pool(x, batch)  # Aggregate node features to graph-level
        x = self.linear(x)
        return torch.sigmoid(x)  # Sigmoid activation function for binary classification

# Training loop
USS_data = [nn_to_data(agent.model) for agent in USS_agents]
URS_data = [nn_to_data(agent.model) for agent in URS_agents]
print(URS_data[0].x.shape)
UPS_data = [nn_to_data(agent.model) for agent in UPS_agents]
assert URS_data[0].x.shape == UPS_data[0].x.shape

torch.Size([134, 4])


In [14]:
# Binary classification between two datasets
dataset1 = USS_data
dataset2 = URS_data
def generate_data(dataset1, dataset2):
    indices = np.random.permutation(len(dataset1) + len(dataset2))
    data = [dataset1[i] if i < len(dataset1) else dataset2[i - len(dataset1)] for i in indices]
    for i in range(len(data)):
        data[i].y = 1.0 if indices[i] < len(dataset1) else 0.0 # Binary labels for each node; 1 = URS, 0 = UPS
        # Hence roughly speaking, 1 = more coherent, 0 = less coherent

    train_data_ratio = 0.8
    train_data, test_data = data[:int(train_data_ratio * len(data))], data[int(train_data_ratio * len(data)):]
    num_node_features = data[0].x.shape[1] # Number of features for each node
    return train_data, test_data, num_node_features

train_data, test_data, num_node_features = generate_data(dataset1, dataset2)
# Loss and optimizer
model = GraphLevelGCN(num_node_features)
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

def train_classifier(model, criterion, optimizer, train_data, test_data, epochs = 40, patience = 3, 
                     epochs_without_improvement = 0, best_loss = float('inf')):
    for epoch in range(epochs):
        avg_train_loss = 0
        for datapt in train_data:
            model.train()
            optimizer.zero_grad()

            # print(f"datapt.x shape: {datapt.x.shape}")  # Should be [num_nodes, num_node_features]
            # print(f"datapt.edge_index shape: {datapt.edge_index.shape}")  # Should be [2, num_edges]
            out = model.forward(datapt)
            # print(out.size())
            # print(torch.tensor([[datapt.y]]).size())
            loss = criterion(out, torch.tensor([[datapt.y]]))  # Adjust shape as necessary
            loss.backward()
            optimizer.step()
            avg_train_loss += loss.item()
        avg_train_loss /= len(train_data)

        avg_test_loss = 0
        for datapt in test_data:
            model.eval()
            with torch.no_grad():
                out = model.forward(datapt)
                loss = criterion(out, torch.tensor([[datapt.y]]))
                avg_test_loss += loss.item()
        avg_test_loss /= len(test_data)
        
        print(f'Epoch {epoch+1}: Average Train Loss: {avg_train_loss}, Average Test Loss: {avg_test_loss}')
        
        # Early Stopping
        if avg_test_loss < best_loss:
            best_loss = avg_test_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print(f'Early stopping at epoch {epoch+1}')
                break

train_classifier(model, criterion, optimizer, train_data, test_data)

Epoch 1: Average Train Loss: 0.96087555669219, Average Test Loss: 1.4202829338610172
Epoch 2: Average Train Loss: 0.7309860734501854, Average Test Loss: 0.843775562942028
Epoch 3: Average Train Loss: 0.7203299042768776, Average Test Loss: 0.7586588948965073
Epoch 4: Average Train Loss: 0.7053088508546352, Average Test Loss: 0.7255915313959121
Epoch 5: Average Train Loss: 0.6966455757617951, Average Test Loss: 0.7026014655828476
Epoch 6: Average Train Loss: 0.7097975861281156, Average Test Loss: 0.7265136957168579
Epoch 7: Average Train Loss: 0.6949755385518074, Average Test Loss: 0.7046276897192001
Epoch 8: Average Train Loss: 0.694453526288271, Average Test Loss: 0.7069998249411583
Early stopping at epoch 8


In [25]:
# Test GCN model on a "more powerful" NN
print(model.forward(dataset1[0]))
print(model.forward(dataset2[0]))
powerful_models = [nn_to_data(train_dqn(env_name = env_name, episodes = 5 * i).model) 
                   for i in [1, 3, 10]]
print([model.forward(data) for data in powerful_models])

tensor([[0.3876]], grad_fn=<SigmoidBackward0>)
tensor([[0.3860]], grad_fn=<SigmoidBackward0>)


[tensor([[0.3863]], grad_fn=<SigmoidBackward0>), tensor([[0.5091]], grad_fn=<SigmoidBackward0>), tensor([[0.5598]], grad_fn=<SigmoidBackward0>)]


- The classifier training process is finicky -- sometimes it overfits, sometimes it underfits -- but sometimes can reach very low loss (< 0.002)
- Even weak classifiers classify powerful models (a.k.a. agents with >15 episodes in CartPole) as having P(URS) = 1, corresponding to coherence ~ $\infty$
- P(USS) / P(URS) is still having trouble as a metric; seems extremely difficult to detect differences between USS and URS-generated policies here with current methods
    - We will need some kind of "more advanced" coherence metric to distinguish more advanced policies; TODO: implement UUS somehow
- Adding node weights to every other node to the features passed into the GCN (such that, in CartPole, the data matrix has shape (134, 134) instead of (134, 1)) makes the GCN much worse, probably because of higher dimensionality
    - Using attention in the GNN does not help, and in fact actively overfits when using (134, 1) data
- Even with sparsity = 0.999, USS is still hard to distinguish
- For simpler discrete environments, maybe a Q-table is enough to solve the problem
- Takes >= 500 episodes, small epsilon to effectively learn DQN policy


In [26]:
# Now classifying q-table agents
env_name = "Taxi-v3"
NUM_EPS_TRAIN_R = 1000
NUM_TRAIN_R_FUNCS = 50
NUM_REWARD_CALLS = 0
env = gym.make(env_name)
def deterministic_random(*args, lb = -1, ub = 1, sparsity = 0.0, continuous = False):
    """
    Create a deterministic random number generator for a given set of arguments.
    Used to generate deterministic reward functions for the coherence classifier.
    [Edit 4/3/24: adapted to continuous state space]"""
    global NUM_REWARD_CALLS
    NUM_REWARD_CALLS += 1
    unique_seed = f"{args}".encode("utf-8")
    random.seed(unique_seed)
    return random.uniform(lb, ub) if random.random() > sparsity else random.uniform(-NEAR_ZERO, NEAR_ZERO)
URS_r_funcs = [lambda *args: deterministic_random(args) for _ in range(NUM_TRAIN_R_FUNCS)]
URS_agents = [train_qtable(env_name = env_name, episodes=NUM_EPS_TRAIN_R, 
                           reward_function = r_func) for r_func in URS_r_funcs]
USS_r_funcs = [lambda *args: deterministic_random(args, sparsity=0.99) for _ in range(NUM_TRAIN_R_FUNCS)]
USS_agents = [train_qtable(env_name = env_name, episodes=NUM_EPS_TRAIN_R,
                            reward_function = r_func) for r_func in USS_r_funcs]
UPS_agents = [QTableAgent(get_state_size(env), env.action_space.n) for _ in range(NUM_TRAIN_R_FUNCS)]

# The Q-Table is already one-hot encoded, so we don't need to convert it to a Data object

In [43]:
# Test ground
print([USS_r_funcs[0](i) for i in range(10)])
test_USS_agent = train_qtable(env_name = env_name, episodes = 50, verbose=True, epsilon_decay = 100, 
                              lr = 0.01, gamma = 0.9, reward_function = USS_r_funcs[0])
test_qtable(gym.make(env_name), test_USS_agent, episodes = 100, reward_function = USS_r_funcs[0])
test_UPS_agent = QTableAgent(get_state_size(env), env.action_space.n)
test_qtable(gym.make(env_name), test_UPS_agent, episodes = 100, reward_function = USS_r_funcs[0])

[6.879610979477867e-10, -8.537837971656553e-10, 1.3453720557822044e-10, -6.91582481382308e-10, 5.834328244211258e-10, 2.689469001408139e-10, -7.747733415111116e-10, 3.5242825351422126e-10, 2.6150563760737437e-10, -8.919200340033589e-10]
Episode: 50, Average total reward: 0.013892658673890956, Epsilon: 0.62
Maximum reward: 8
Average reward: 1.8141342803036227
Maximum reward: 8
Average reward: 0.019907783115625098


In [28]:
class FCNNBinary(nn.Module):
    def __init__(self, num_node_features):
        super(FCNNBinary, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(num_node_features, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

    def forward(self, data):
        x = data.x
        x = self.fc(x)
        return torch.sigmoid(x)
    
def qtable_to_feat(qtable: torch.Tensor, label):
    # In qtable, rows are states and columns are actions taken in that state
    return Data(x = torch.flatten(qtable), y = label) # Naive approach

In [29]:
UPS_agents = [QTableAgent(get_state_size(env), env.action_space.n) for _ in range(NUM_TRAIN_R_FUNCS)]

for agent in UPS_agents:
    for row in agent.q_table:
        for i in range(len(row)):
            row[i] = np.random.uniform(-1, 1) # set each value to a random number between -1 and 1
dataset1 = [qtable_to_feat(torch.tensor(agent.q_table, dtype=torch.float32), 1) for agent in USS_agents]
dataset2 = [qtable_to_feat(torch.tensor(agent.q_table, dtype=torch.float32), 0) for agent in URS_agents] # URS = 1, UPS = 0

def generate_fcnn_data(dataset1, dataset2):
    indices = np.random.permutation(len(dataset1) + len(dataset2))
    data = [dataset1[i] if i < len(dataset1) else dataset2[i - len(dataset1)] for i in indices]
    for i in range(len(data)):
        data[i].y = 1.0 if indices[i] < len(dataset1) else 0.0 # Binary labels for each node; 1 = URS, 0 = UPS
        # Hence roughly speaking, 1 = more coherent, 0 = less coherent

    train_data_ratio = 0.8
    train_data, test_data = data[:int(train_data_ratio * len(data))], data[int(train_data_ratio * len(data)):]
    num_node_features = data[0].x.shape[0] # Number of features for each node
    return train_data, test_data, num_node_features

def train_fcnn_classifier(model, criterion, optimizer, train_data, test_data, epochs = 40, patience = 3, 
                          epochs_without_improvement = 0, best_loss = float('inf')):
    for epoch in range(epochs):
        avg_train_loss = 0
        for datapt in train_data:
            model.train()
            optimizer.zero_grad()

            out = model.forward(datapt)
            assert isinstance(out, torch.Tensor), f"Expected model.forward to return a tensor, but got {out}"
            loss = criterion(out, torch.tensor([datapt.y]))  # Adjust shape as necessary
            loss.backward()
            optimizer.step()
            avg_train_loss += loss.item()
        avg_train_loss /= len(train_data)

        avg_test_loss = 0
        for datapt in test_data:
            model.eval()
            with torch.no_grad():
                out = model.forward(datapt)
                loss = criterion(out, torch.tensor([datapt.y]))
                avg_test_loss += loss.item()
        avg_test_loss /= len(test_data)
        
        print(f'Epoch {epoch+1}: Average Train Loss: {avg_train_loss}, Average Test Loss: {avg_test_loss}')
        
        # Early Stopping
        if avg_test_loss < best_loss:
            best_loss = avg_test_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print(f'Early stopping at epoch {epoch+1}')
                break

train_data, test_data, num_node_features = generate_fcnn_data(dataset1, dataset2)
print(num_node_features)
model = FCNNBinary(num_node_features)
criterion = torch.nn.BCELoss()  
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
train_fcnn_classifier(model, criterion, optimizer, train_data, test_data)

3000
Epoch 1: Average Train Loss: 0.11231718239287707, Average Test Loss: 0.00022335199305162818
Epoch 2: Average Train Loss: 0.0002989479297568248, Average Test Loss: 8.912703181408952e-05
Epoch 3: Average Train Loss: 0.00014629946607331237, Average Test Loss: 4.654434895650121e-05
Epoch 4: Average Train Loss: 8.502696133767995e-05, Average Test Loss: 2.7614850406871973e-05
Epoch 5: Average Train Loss: 5.480071272152748e-05, Average Test Loss: 1.7909364215942464e-05
Epoch 6: Average Train Loss: 3.78367068058298e-05, Average Test Loss: 6.568596279142097e-06
Epoch 7: Average Train Loss: 3.982042360476125e-06, Average Test Loss: 5.3644196640334485e-08
Epoch 8: Average Train Loss: 2.2202764586864988e-07, Average Test Loss: 2.384186146286993e-08
Epoch 9: Average Train Loss: 1.2516986487298708e-07, Average Test Loss: 1.1920930376163597e-08
Epoch 10: Average Train Loss: 8.344655171299564e-08, Average Test Loss: 5.9604651880817984e-09
Epoch 11: Average Train Loss: 6.109478771776366e-08, Avera

In [30]:
def generate_UQS_qagent(rand_qtable, gamma, env: gym.Env, episodes = 500):
    """
    Train a Q-table agent based on a reward function uniformly sampled from the set of 
    possible reward functions compatible with the given random Q-table."""
    # Generate the reward function using the Bellman equation
    r_table = np.zeros(rand_qtable.shape)
    for s in range(rand_qtable.shape[0]):
        for a in range(rand_qtable.shape[1]):
            env.reset()
            env.unwrapped.s = s
            ns = env.step(a)[0]
            r_table[s, a] = rand_qtable[s, a] - gamma * np.max(rand_qtable[ns]) #assuming greedy policy
    
    # Train the agent
    r_func = lambda s, a, *args: r_table[s, a]
    return train_qtable(env_name = env.spec.id, episodes = episodes, reward_function = r_func)

UQS_agents = [generate_UQS_qagent(agent.q_table, 0.9, env, episodes = NUM_EPS_TRAIN_R) for agent in UPS_agents]

In [31]:
def generate_UVS_qagent(rand_values, gamma, env: gym.Env, episodes = 500, lb = -1, ub = 1):
    """
    Train a Q-table agent based on a reward function uniformly sampled from the set of 
    possible reward functions compatible with the given values for each state.
    Assumes a uniform distribution between [lb, ub]."""
    r_table = np.zeros((len(rand_values), env.action_space.n))
    for s in range(len(rand_values)):
        next_states = np.zeros(env.action_space.n)
        for a in range(env.action_space.n):
            env.reset()
            env.unwrapped.s = s
            next_states[a] = env.step(a)[0]
        #v(s) = max_a(R(s, a) + gamma * v(s'))
        reward_ub = rand_values[s] - np.array([gamma * rand_values[int(ns)] for ns in next_states])
        taut_probs = np.zeros(env.action_space.n)
        for i in range(env.action_space.n):
            all_except_i = np.delete(np.arange(env.action_space.n), i)
            taut_probs[i] = np.prod((reward_ub[all_except_i] + 1) / 2)
            # probability that all other rewards at action j are less than reward_ub[j]
        taut_probs /= np.sum(taut_probs)
        taut = np.random.choice(env.action_space.n, p = taut_probs) 
        #index of the action where the reward is equal to the maximum

        rewards = np.full(env.action_space.n, float('inf'))
        while np.any(rewards >= reward_ub): #while any of the rewards are greater than the upper bound
            rewards = np.random.uniform(-1, 1, env.action_space.n)
        rewards[taut] = reward_ub[taut]
        r_table[s] = rewards
    
    r_func = lambda s, a, *args: r_table[s, a]
    return train_qtable(env_name = env.spec.id, episodes = episodes, reward_function = r_func)

# UVS_agents = [generate_UVS_qagent(np.random.uniform(-1, 1, env.unwrapped.s), 0.9, env, episodes = NUM_EPS_TRAIN_R) for _ in range(NUM_TRAIN_R_FUNCS)]
# this currently takes way too long so it has been commented out

In [32]:
UQS_agents[0].q_table - UPS_agents[0].q_table

array([[-0.19997789,  0.8480003 , -0.94945213,  0.36469594, -0.21314675,
         0.29373737],
       [-0.51630777, -0.73252197, -0.82781936,  0.27989021, -0.46934794,
        -0.20870209],
       [ 0.03481552,  0.09088575, -0.27765734, -0.20246238, -0.03002716,
        -0.28835059],
       ...,
       [-0.21234347, -0.22230378,  0.44615135, -0.80687533, -0.02927372,
        -0.87625879],
       [ 0.93125485,  0.60710758, -0.53284528, -0.6932406 , -0.1986679 ,
        -0.55611788],
       [ 0.64275359, -0.31388159, -0.50294809,  0.34698775,  0.06680938,
         0.12857975]])

In [41]:
### Only attaching reward to terminal states (kind of like UUS? but with the inductive biases)

def det_rand_terminal(done: bool, *args, lb = -1, ub = 1, sparsity = 0.0):
    """
    Create a deterministic random number generator for a given set of arguments.
    Used to generate deterministic reward functions for the coherence classifier. """
    global NUM_REWARD_CALLS
    NUM_REWARD_CALLS += 1
    if not done:
        return random.uniform(-NEAR_ZERO, NEAR_ZERO)
    unique_seed = f"{args}".encode("utf-8")
    random.seed(unique_seed)
    return random.uniform(lb, ub) if random.random() > sparsity else random.uniform(-NEAR_ZERO, NEAR_ZERO)

UUS_agents = [train_qtable(env_name = env_name, episodes = NUM_EPS_TRAIN_R, 
                           reward_function = lambda *args: det_rand_terminal(*args)) for _ in range(NUM_TRAIN_R_FUNCS)]

In [33]:
### Turn the state and action space of Taxi-v3 into a graph

from collections import defaultdict
taxi_env = gym.make("Taxi-v3")
taxi_env.reset()
# Initialize containers for graph data
edges = defaultdict(list)
edge_attr = defaultdict(list)

# A helper function to encode the state into a single number (node index)
def state_to_node(taxi_row, taxi_col, pass_loc, dest_idx):
    # This encoding assumes specific knowledge about the Taxi-v3 state space size
    return taxi_row * 100 + taxi_col * 20 + pass_loc * 4 + dest_idx
    # max = 4 * 100 + 4 * 20 + 4 * 4 + 3 = 400 + 80 + 16 + 3 = 499

# Iterate through all possible states and actions to construct the graph
for taxi_row in range(5):
    for taxi_col in range(5):
        for pass_loc in range(5):  # 4 locations + 1 for 'in taxi'
            for dest_idx in range(4):
                current_state = state_to_node(taxi_row, taxi_col, pass_loc, dest_idx)
                for action in range(taxi_env.action_space.n):
                    # Set the environment to the current state
                    taxi_env.unwrapped.s = current_state
                    # Take action and observe the next state and reward
                    next_state, reward, done, _ = taxi_env.step(action)
                    # Add edge from current state to next state
                    edges[current_state].append(next_state)
                    # Optionally, use rewards as edge attributes
                    # edge_attr[(current_state, next_state)].append(reward)
                    taxi_env.reset()


# Convert edges and edge attributes to tensors
edge_index = []
for src, dsts in edges.items():
    for dst in dsts:
        edge_index.append([src, dst])
edge_index = torch.tensor(edge_index).t().contiguous()
edge_index, edge_index.shape

(tensor([[  0,   0,   0,  ..., 499, 499, 499],
         [100,   0,  20,  ..., 479, 499, 499]]),
 torch.Size([2, 3000]))

In [72]:
def greedy_policy(q_table):
    return torch.tensor(np.argmax(q_table, axis=1).reshape(-1, 1).astype(np.float32))
def random_policy(state_dim):
    return torch.randint(0, 6, (state_dim, 1)).float()
def prep_qtable(q_table):
    return torch.tensor(q_table, dtype=torch.float32)

dataset1 = [Data(x = prep_qtable(agent.q_table), edge_index = edge_index, y = 1) for agent in USS_agents]
dataset2 = [Data(x = prep_qtable(agent.q_table), edge_index = edge_index, y = 0) for agent in URS_agents]
#dataset2 = [Data(x = random_policy(agent.q_table.shape[0]), edge_index = edge_index, y = 0) for agent in UPS_agents]
# ^ random_policy = UPS sampling
print(dataset1[0].x.shape)

train_data, test_data, num_node_features = generate_data(dataset1, dataset2)
model = GraphLevelGCN(num_node_features)
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1, weight_decay=5e-4)
train_classifier(model, criterion, optimizer, train_data, test_data)

torch.Size([500, 6])
Epoch 1: Average Train Loss: 0.06216179677363045, Average Test Loss: 0.00010359175885241712
Epoch 2: Average Train Loss: 0.0004164183939451505, Average Test Loss: 0.0006073502612707671
Epoch 3: Average Train Loss: 0.0007531430309427378, Average Test Loss: 0.0007604949561937247
Epoch 4: Average Train Loss: 0.0008634724525109049, Average Test Loss: 0.0006281233057961799
Early stopping at epoch 4


In [73]:
print(torch.transpose(dataset1[0].x, 0, -1)[0:10, 0:10]) # Example greedy policies
print(torch.transpose(dataset2[0].x, 0, -1)[0:10, 0:10])
print(model.forward(dataset1[0]))
print(model.forward(dataset2[0]))

tensor([[ 0.0000e+00, -8.5504e-12,  0.0000e+00, -4.1088e-02,  0.0000e+00,
          0.0000e+00,  0.0000e+00, -6.6273e-13,  5.3457e-12,  6.9077e-11],
        [ 0.0000e+00, -3.9739e-13,  2.7565e-09,  7.3998e-10,  7.7897e-11,
          0.0000e+00,  3.4917e-12,  6.0604e-09,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00, -7.2252e-12,  0.0000e+00,  0.0000e+00,  3.4676e-09,
          0.0000e+00, -2.4605e-11,  0.0000e+00,  6.7238e-11,  5.3296e-11],
        [ 0.0000e+00, -7.2347e-12,  8.4538e-12,  0.0000e+00,  8.8427e-10,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  6.9332e-12,  7.6284e-12],
        [ 0.0000e+00, -7.9438e-12,  0.0000e+00,  0.0000e+00,  3.3650e-10,
          0.0000e+00,  0.0000e+00,  1.0530e-11, -5.9638e-12,  1.0271e-11],
        [ 0.0000e+00,  0.0000e+00, -2.7835e-10,  0.0000e+00,  1.0513e-09,
          0.0000e+00,  0.0000e+00,  3.0895e-09, -3.5969e-12,  5.3653e-12]])
tensor([[ 0.0000e+00, -8.6664e-03, -2.8030e-03, -2.4628e-02,  0.0000e+00,
          0.0000e+00,  3.0579e+

In [74]:
taxi_model = train_qtable(env_name = "Taxi-v3", episodes = 3000, verbose = True, print_every = 1000, 
                          return_reward = False)
taxi_data = Data(x = torch.tensor(prep_qtable(taxi_model.q_table)), edge_index = edge_index)
print(model.forward(taxi_data))

Episode: 1000, Average total reward: -423.2692692692693, Epsilon: 0.14
Episode: 2000, Average total reward: -174.55655655655656, Epsilon: 0.03
Episode: 3000, Average total reward: -108.48448448448448, Epsilon: 0.01
tensor([[0.9863]], grad_fn=<SigmoidBackward0>)


  taxi_data = Data(x = torch.tensor(prep_qtable(taxi_model.q_table)), edge_index = edge_index)


In [75]:
taxi_model.q_table

array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ],
       [-2.5970681 , -2.58678979, -2.59171511, -2.59413361, -2.59139141,
        -3.18260675],
       [-1.62748602, -1.63384201, -1.64086321, -1.63292275, -1.6048012 ,
        -3.79311952],
       ...,
       [-0.8030423 , -0.74959066, -0.79656961, -0.80100224, -1.95104146,
        -1.15839785],
       [-2.05603415, -2.05609945, -2.04796067, -2.05448177, -2.06334143,
        -2.58135788],
       [-0.05834885, -0.05008236, -0.03969999,  0.7222398 , -0.297306  ,
        -0.1       ]])

In [53]:
### Find optimal policy by MCTS and test coherence classifier on it
class Node:
    def __init__(self, state, parent=None, action = None):
        self.state = state
        self.parent = parent
        self.action = action # Action taken to reach this state
        self.children = []
        self.visits = 0
        self.value = 0

    def add_child(self, child):
        self.children.append(child)

    def update(self, reward):
        self.visits += 1
        self.value += reward

    def is_fully_expanded(self):
        return len(self.children) == env.action_space.n

    def best_child(self, c_param=1.4):
        choices_weights = [(child.value / child.visits) + c_param * np.sqrt((2 * np.log(self.visits) / child.visits)) for child in self.children]
        return self.children[np.argmax(choices_weights)]

def selection(node):
    while not node.is_fully_expanded():
        if not node.children:
            return expansion(node)
        else:
            node = node.best_child()
    return node

def expansion(node):
    tried_children = [child.state for child in node.children]
    new_state = None
    # Attempt to find an untried action/state
    for _ in range(env.action_space.n):
        action = env.action_space.sample()
        env.reset()  # Reset to the current node's state
        env.env.s = node.state  # This uses internal knowledge about gym's environment state setting
        next_state, _, _, _ = env.step(action)
        if next_state not in tried_children:
            new_state = next_state
            break
    if new_state is not None:
        child_node = Node(new_state, parent=node, action=action)
        node.add_child(child_node)
        return child_node
    return node  # This case handles leaf nodes

def simulation(node):
    current_state = node.state
    done = False
    total_reward = 0

    for _ in range(100):  # Limit the number of steps to prevent infinite loop
        if done:
            break
        action = env.action_space.sample()
        env.reset()  # Reset to the current simulation node's state
        env.env.s = current_state
        next_state, reward, done, _ = env.step(action)
        total_reward += reward
        current_state = next_state

    return total_reward

def backpropagation(node, reward):
    while node is not None:
        node.update(reward)
        node = node.parent

def mcts(root, iterations=1000):
    for _ in range(iterations):
        leaf = selection(root)
        reward = simulation(leaf)
        backpropagation(leaf, reward)

# Initialize the Gym environment
env = gym.make(env_name)

# MCTS
initial_state = env.reset()
root_node = Node(initial_state)
mcts(root_node, iterations=1000)

# After running MCTS, your root_node will have an approximated policy.
# You can select actions by traversing the tree from the root using best_child()

In [54]:
### Test MCTS

def choose_action(node):
    # Choose the child with the highest visit count
    if node.children:
        return max(node.children, key=lambda child: child.visits).action
    else:
        return None

def simulate_episode_from_root(env, root_node):
    total_reward = 0
    done = False
    current_node = root_node
    env.reset()
    env.env.s = current_node.state
    
    while not done and current_node is not None:
        action = choose_action(current_node)
        if action is None:
            # No more information in the tree; choose random action
            action = env.action_space.sample()
        
        next_state, reward, done, _ = env.step(action)  # Execute the chosen action
        total_reward += reward
        
        # Move to the next node in the tree, if it exists
        next_node = None
        for child in current_node.children:
            if child.action == action:
                next_node = child
                break
        current_node = next_node

    return total_reward


# Test the policy derived from the MCTS root node
env = gym.make('Taxi-v3')
average_reward = np.mean([simulate_episode_from_root(env, root_node) for _ in range(100)])
print(f"Average Reward from the MCTS policy: {average_reward}")


Average Reward from the MCTS policy: -772.74


- GNNs over the environment work really well, even on USS/URS and UUS/URS (identifying "sparsity")
    - Real Taxi agents are graded strongly towards USS and UUS end
    - Looking at the q-tables, there are some noticible differences (e.g. USS q-tables tend to have lower magnitude)
    - GNNs don't work yet on USS/URS when only the policy is passed in
    - Started working once I put the fix in of changing the reward function of the environment *for each state*