In [1]:
import gym
import numpy as np
import random
from collections import deque
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch_geometric.nn.conv import GCNConv
import math

Setup:
- Starting point: just try to train classifier on RL policies

In [2]:
### DQN implementation

# Define the Q-network
class DQN(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super(DQN, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, x):
        return self.fc(x)
    
    def get_weights(self):
        return self.state_dict()

# Experience Replay Buffer
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        state = np.expand_dims(state, 0)
        next_state = np.expand_dims(next_state, 0)
            
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        return np.concatenate(state), action, reward, np.concatenate(next_state), done

    def __len__(self):
        return len(self.buffer)

# DQN Agent
class DQNAgent:
    def __init__(self, state_dim, action_dim, hidden_dim=64, lr=1e-2, batch_size=64, gamma=0.99, replay_size=1000):
        self.model = DQN(state_dim, action_dim, hidden_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.replay_buffer = ReplayBuffer(replay_size)
        self.batch_size = batch_size
        self.gamma = gamma
        self.action_dim = action_dim
    
    def update(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        state, action, reward, next_state, done = self.replay_buffer.sample(self.batch_size)
        
        state = torch.FloatTensor(state)
        next_state = torch.FloatTensor(next_state)
        if len(state.shape) == 1:
            state = state.reshape(-1, 1)
        if len(next_state.shape) == 1:
            next_state = next_state.reshape(-1, 1)
        action = torch.LongTensor(action)
        reward = torch.FloatTensor(reward)
        done = torch.FloatTensor(done)

        q_values = self.model.forward(state)
        next_q_values = self.model.forward(next_state)

        # state = state.T
        # next_state = next_state.T
        
        q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
        next_q_value = next_q_values.max(1)[0]
        expected_q_value = reward + self.gamma * next_q_value * (1 - done)
        
        loss = nn.MSELoss()(q_value, expected_q_value.detach())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def act(self, state, epsilon):
        if random.random() > epsilon:
            state = torch.FloatTensor(np.expand_dims(state, 0))
            q_value = self.model(state)
            action = q_value.max(-1)[1].item()
        else:
            action = random.randrange(self.action_dim)
        return action
    
class QTableAgent:
    def __init__(self, state_dim, action_dim, lr=1e-2, gamma=0.99):
        self.q_table = np.zeros((state_dim, action_dim))
        self.lr = lr
        self.gamma = gamma
        self.action_dim = action_dim
    
    def update(self, state, action, reward, next_state, done):
        q_value = self.q_table[state, action]
        next_q_value = np.max(self.q_table[next_state])
        expected_q_value = reward + self.gamma * next_q_value * (1 - done)
        
        self.q_table[state, action] += self.lr * (expected_q_value - q_value)
    
    def act(self, state, epsilon):
        if random.random() > epsilon:
            action = np.argmax(self.q_table[state])
        else:
            action = random.randrange(self.action_dim)
        return action

In [3]:
### train_dqn
NUM_NON_ZERO_REWARDS = 0
def one_hot_state(state, env):
    state_arr = np.zeros(env.observation_space.n)
    state_arr[state] = 1
    return state_arr

def train_dqn(env_name="CartPole-v1", episodes=500, epsilon_start=1.0, epsilon_final=0.01, 
              epsilon_decay=500, reward_function = None, verbose = False, return_reward = False, 
              print_every=50, **kwargs):
    """
    Train a DQN agent on the specified environment.
    
    Args:
        env_name: str
            Name of the environment to train the agent on.
        episodes: int
            Number of episodes to train the agent for.
        epsilon_start: float
            Initial epsilon value for epsilon-greedy action selection.
        epsilon_final: float
            Final epsilon value for epsilon-greedy action selection.
        epsilon_decay: float
            Decay rate for epsilon.
        reward_function: function
            Optional reward function to use for training.
        verbose: bool
            Whether to print training progress.

    Returns:
        DQNAgent: trained DQN agent. 
    """
    global NUM_NON_ZERO_REWARDS
    env = gym.make(env_name)
    if len(env.observation_space.shape) == 0:
        state_dim = env.observation_space.n
    else:
        state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    agent = DQNAgent(state_dim, action_dim, **kwargs)
    
    epsilon_by_frame = lambda frame_idx: epsilon_final + (epsilon_start - epsilon_final) * np.exp(-1. * frame_idx / epsilon_decay)
    
    rewards = np.zeros(episodes) 
    is_state_discrete = hasattr(env.observation_space, 'n')
    for episode in range(episodes):
        state = env.reset() # Reset the environment, reward
        if is_state_discrete and state_dim == env.observation_space.n:
            state = one_hot_state(state, env)
        episode_reward = 0
        while True:
            epsilon = epsilon_by_frame(episode)
            # One-hot encode the state
            action = agent.act(state, epsilon)
            next_state, reward, done, _ = env.step(action)
            if is_state_discrete and state_dim == env.observation_space.n:
                next_state = one_hot_state(next_state, env)

            if reward_function: #custom reward function
                reward = reward_function(done, state, action, next_state)
            NUM_NON_ZERO_REWARDS += 0 if math.isclose(reward, 0) else 1
            
            agent.replay_buffer.push(state, action, reward, next_state, done)
            state = next_state
            episode_reward += reward
            
            agent.update()
            
            if done:
                break
            # print(f"Episode: {episode+1}, Total reward: {episode_reward}, Epsilon: {epsilon:.2f}")

        rewards[episode] = episode_reward
        # Optional: Render the environment to visualize training progress
        if verbose and episode % print_every == print_every - 1:
        #     render_env(env, agent)
            print(f"Episode: {episode+1}, Average total reward: {np.average(rewards[episode - print_every + 1 : episode])}, Epsilon: {epsilon:.2f}")

    env.close()
    return agent if not return_reward else (agent, rewards)

# Optional: Function to render the environment with the current policy
def render_env(env, agent):
    state = env.reset()
    done = False
    while not done:
        action = agent.act(state, 0)  # Using 0 epsilon for greedy action selection
        # print(env.step(action))
        next_state, reward, done, _ = env.step(action)
        env.render()
        state = next_state

In [4]:
def train_qtable(env_name="CartPole-v1", episodes=500, epsilon_start=1.0, epsilon_final=0.01, 
              epsilon_decay=500, reward_function = None, verbose = False, return_reward = False, 
              print_every=50, **kwargs):
    """
    Train a Q-table agent on the specified environment."""
    global NUM_NON_ZERO_REWARDS
    env = gym.make(env_name)
    if len(env.observation_space.shape) == 0:
        state_dim = env.observation_space.n
    else:
        state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    agent = QTableAgent(state_dim, action_dim, **kwargs)

    rewards = np.zeros(episodes)
    epsilon_by_frame = lambda frame_idx: epsilon_final + (epsilon_start - epsilon_final) * np.exp(-1. * frame_idx / epsilon_decay)
    for episode in range(episodes):
        state = env.reset()
        episode_reward = 0
        while True:
            epsilon = epsilon_by_frame(episode)
            action = agent.act(state, epsilon)
            next_state, reward, done, _ = env.step(action)
            if reward_function:
                reward = reward_function(done, state, action, next_state)

            agent.update(state, action, reward, next_state, done)
            state = next_state
            episode_reward += reward
            if done:
                break

        rewards[episode] = episode_reward
        if verbose and episode % print_every == print_every - 1:
            print(f"Episode: {episode+1}, Average total reward: {np.average(rewards[episode - print_every + 1 : episode])}, Epsilon: {epsilon:.2f}")
        
    env.close()
    return agent if not return_reward else (agent, rewards)
    

In [5]:
### test_dqn
NEAR_ZERO = 1e-9
def test_dqn(env, agent, episodes=10, reward_function=None, verbose = False):
    print(f"Maximum reward: {env.spec.reward_threshold}")
    average_value = 0
    for episode in range(episodes):
        # if episode == 0:
        #     render_env(env, agent)
        state = env.reset()
        if len(env.observation_space.shape) == 0:
            state = one_hot_state(state, env)
        episode_reward = 0
        done = False
        while not done:
            action = agent.act(state, 0)  # Using 0 epsilon for greedy action selection
            next_state, reward, done, _ = env.step(action)
            if len(env.observation_space.shape) == 0:
                next_state = one_hot_state(next_state, env)
            if reward_function:
                reward = reward_function(done, state, action, next_state)
            episode_reward += reward
            state = next_state
        if verbose:
            print(f"Episode: {episode+1}, Total reward: {episode_reward}")
        average_value += episode_reward
    average_value /= episodes
    print(f"Average reward: {average_value}")
    

In [6]:
def test_qtable(env, agent, episodes=10, reward_function=None, verbose = False):
    """
    Test a Q-table agent on the specified environment.
    (This is basically test_dqn but without the one-hot encoding.)
    """
    print(f"Maximum reward: {env.spec.reward_threshold}")
    average_value = 0
    for episode in range(episodes):
        state = env.reset()
        episode_reward = 0
        done = False
        while not done:
            action = agent.act(state, 0)  # Using 0 epsilon for greedy action selection
            next_state, reward, done, _ = env.step(action)
            if reward_function:
                reward = reward_function(done, state, action, next_state)
            episode_reward += reward
            state = next_state
        if verbose:
            print(f"Episode: {episode+1}, Total reward: {episode_reward}")
        average_value += episode_reward
    average_value /= episodes
    print(f"Average reward: {average_value}")

In [10]:
env_name = "Taxi-v3"
agent, rewards = train_qtable(env_name = env_name, episodes = 10000, verbose = True, return_reward = True,
                           epsilon_decay=10, lr=0.1, gamma=0.9, print_every=1000, 
                           reward_function = lambda done, *args: 1 if done else -1)  
# rewards

Episode: 1000, Average total reward: -147.58058058058057, Epsilon: 0.01
Episode: 2000, Average total reward: -59.01601601601602, Epsilon: 0.01
Episode: 3000, Average total reward: -26.87887887887888, Epsilon: 0.01
Episode: 4000, Average total reward: -18.01801801801802, Epsilon: 0.01
Episode: 5000, Average total reward: -15.036036036036036, Epsilon: 0.01
Episode: 6000, Average total reward: -13.521521521521521, Epsilon: 0.01
Episode: 7000, Average total reward: -12.774774774774775, Epsilon: 0.01
Episode: 8000, Average total reward: -12.16916916916917, Epsilon: 0.01
Episode: 9000, Average total reward: -11.883883883883884, Epsilon: 0.01
Episode: 10000, Average total reward: -11.901901901901901, Epsilon: 0.01


In [11]:
test_qtable(gym.make(env_name), agent, episodes = 100)
# agent.model.get_weights()

Maximum reward: 8
Average reward: -76.4


In [27]:
### NN to GCN Data conversion

#agent.model.get_weights()

# Define a simple GCN model
from torch_geometric.data import Data
class GCN(torch.nn.Module):
    def __init__(self, data):
        super(GCN, self).__init__()
        # Define the GCN layers
        self.conv1 = GCNConv(data.num_node_features, 4)  # Input features to hidden
        self.conv2 = GCNConv(4, 2)  # Hidden to output features
        self.data = data

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        
        # Pass data through the first GCN layer, then apply ReLU
        x = torch.relu(self.conv1(x, edge_index))
        # Pass data through the second GCN layer
        x = self.conv2(x, edge_index)
        return x
    
    
def nn_to_data(model: nn.Module) -> Data:
    edges = []

    # Counter for global neuron index
    idx = 0

    # Iterate over each layer in the network
    base = next(model.children())
    if isinstance(base, nn.Sequential):
        layers = list(base.children())
        layers2 = list(base.children())
    else:
        layers = list(model.children()) # iterator over the layers of the model
        layers2 = list(model.children())
    
    num_nodes = layers2[0].weight.shape[1] + sum([layer.weight.shape[0] for layer in layers2 if isinstance(layer, nn.Linear)])
    num_node_features = num_nodes
    node_features = torch.zeros(num_nodes, num_node_features)
    # shape = (num_nodes, num_node_features), where the node features are the bias of each node
    # and the weights of the edges to each node (zero if there is no edge)

    for layer in layers:
        if isinstance(layer, nn.Linear):
            # Update edges based on the weight matrix
            input_dim = layer.weight.shape[1]
            output_dim = layer.weight.shape[0]
            for i in range(input_dim):  # Input neurons (e.g. 4)
                for j in range(output_dim):  # Output neurons (e.g. 64)
                    edges.append((idx + i, idx + input_dim + j))
            
            # Update node features (e.g., biases)
            biases = torch.tensor(layer.bias.detach().numpy())
            edge_weights = torch.tensor(layer.weight.detach().numpy().T)
            node_features[idx + input_dim:idx + input_dim + output_dim, 0] = biases
            node_features[idx:idx + input_dim, 1+idx:1+idx+output_dim] = edge_weights
            node_features[idx + input_dim:idx + input_dim + output_dim, 1+idx:1+idx+input_dim] = edge_weights.T
            
            # Update the global neuron index
            idx += input_dim

    # Convert lists to PyTorch tensors
    num_nonzero = [np.count_nonzero(node_features[i]) for i in range(node_features.shape[0])]
    # print(num_nonzero)
    row_mean, row_median, row_var = torch.mean(node_features[:, 1:], dim=1), torch.median(node_features[:, 1:], dim=1)[0], torch.var(node_features[:, 1:], dim=1)
    x = torch.stack([node_features[:, 0], row_mean, row_median, row_var]).T
    # print(x.shape)
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    return Data(x=x, edge_index=edge_index)

agent = train_dqn(env_name = "CartPole-v1", episodes = 1000, verbose = True, return_reward = False)
data = nn_to_data(agent.model)
gcn = GCN(data)
# data.x.shape, data.edge_index.shape
# print(data.x)

#Debug
out_of_bounds = data.edge_index >= data.x.shape[0]
if out_of_bounds.any():
    print("Out-of-bounds indices found at locations:")
    print(data.edge_index[:, out_of_bounds.any(dim=0)])

Episode: 50, Average total reward: 25.20408163265306, Epsilon: 0.91
Episode: 100, Average total reward: 30.918367346938776, Epsilon: 0.82
Episode: 150, Average total reward: 36.244897959183675, Epsilon: 0.74
Episode: 200, Average total reward: 51.53061224489796, Epsilon: 0.67
Episode: 250, Average total reward: 65.10204081632654, Epsilon: 0.61
Episode: 300, Average total reward: 69.73469387755102, Epsilon: 0.55
Episode: 350, Average total reward: 99.59183673469387, Epsilon: 0.50
Episode: 400, Average total reward: 167.24489795918367, Epsilon: 0.46
Episode: 450, Average total reward: 158.81632653061226, Epsilon: 0.41
Episode: 500, Average total reward: 99.95918367346938, Epsilon: 0.37
Episode: 550, Average total reward: 167.53061224489795, Epsilon: 0.34
Episode: 600, Average total reward: 105.61224489795919, Epsilon: 0.31
Episode: 650, Average total reward: 169.55102040816325, Epsilon: 0.28
Episode: 700, Average total reward: 80.6734693877551, Epsilon: 0.25
Episode: 750, Average total r

In [25]:
# Dataset generation
env_name = "CartPole-v1"
env = gym.make(env_name)
NEAR_ZERO = 1e-9
NUM_REWARD_CALLS = 0
NUM_NON_ZERO_REWARDS = 0
def deterministic_random(*args, lb = -1, ub = 1, sparsity = 0.0, continuous = False):
    """
    Create a deterministic random number generator for a given set of arguments.
    Used to generate deterministic reward functions for the coherence classifier.
    [Edit 4/3/24: adapted to continuous state space]"""
    global NUM_REWARD_CALLS
    NUM_REWARD_CALLS += 1
    unique_seed = f"{args}".encode("utf-8")
    random.seed(unique_seed)
    return random.uniform(lb, ub) if random.random() > sparsity else random.uniform(-NEAR_ZERO, NEAR_ZERO)

NUM_TRAIN_R_FUNCS = 50
NUM_EPS_TRAIN_R = 50
URS_r_funcs = [lambda *args: deterministic_random(args) for _ in range(NUM_TRAIN_R_FUNCS)]
URS_agents = [train_dqn(env_name = env_name, 
                        episodes=NUM_EPS_TRAIN_R, reward_function=r_func) for r_func in URS_r_funcs]
USS_r_funcs = [lambda *args: deterministic_random(args, sparsity=0.99) for _ in range(NUM_TRAIN_R_FUNCS)]
print(f"Number of reward function calls: {NUM_REWARD_CALLS}")
print(f"Number of non-zero rewards: {NUM_NON_ZERO_REWARDS}")
USS_agents = [train_dqn(env_name = env_name, 
                        episodes=NUM_EPS_TRAIN_R, reward_function=r_func) for r_func in USS_r_funcs]

Number of reward function calls: 59020
Number of non-zero rewards: 59020


In [11]:
# Test if deterministic_random is deterministic and has the correct sparsity
assert deterministic_random(1, 2, 3, 4) == deterministic_random(1, 2, 3, 4)
assert not deterministic_random(1, 2, 3, 4) == deterministic_random(1, 2, 3, 6)
[deterministic_random(1, 2, 3, i, sparsity = 0.5) for i in range(10)]

[0.6014137224608205,
 5.734868810378968e-10,
 0.18947200717913826,
 -0.11464719428521586,
 1.9375194864306798e-11,
 3.1131227593489704e-10,
 -7.023178277046693e-10,
 2.965355797951794e-10,
 0.41831271768541045,
 -3.207247699354683e-10]

In [12]:
# Test when do USS agents have non-zero rewards
env_name = "CartPole-v1"
USS_test_r_func = lambda *args: deterministic_random(args, sparsity=0.0)
assert USS_test_r_func(42) == USS_test_r_func(42)
USS_test_agent = train_dqn(env_name = env_name, episodes=500, reward_function=USS_test_r_func, 
                           verbose = True)
# Epsilon measuring how much the agent is exploring

Episode: 50, Average total reward: -0.08030391727204088, Epsilon: 0.91
Episode: 100, Average total reward: 1.1898305988496911, Epsilon: 0.82
Episode: 150, Average total reward: 0.5990051160714706, Epsilon: 0.74
Episode: 200, Average total reward: -0.20225874692314827, Epsilon: 0.67
Episode: 250, Average total reward: -0.3040020801607395, Epsilon: 0.61
Episode: 300, Average total reward: 0.088487752426798, Epsilon: 0.55
Episode: 350, Average total reward: 0.17828783475711782, Epsilon: 0.50
Episode: 400, Average total reward: 0.22612443296435566, Epsilon: 0.46
Episode: 450, Average total reward: 0.48808380894686876, Epsilon: 0.41
Episode: 500, Average total reward: -0.8578852051258352, Epsilon: 0.37


In [13]:
# epsilon_final, epsilon_start, epsilon_decay = 0.01, 1.0, 500
# [epsilon_final + (epsilon_start - epsilon_final) * np.exp(-1. * frame_idx / epsilon_decay) for frame_idx in range(500)]
test_dqn(gym.make(env_name), USS_test_agent, reward_function=USS_test_r_func)

Maximum reward: 475.0
Average reward: -3.051088483330308


In [28]:
### Define and train GCN classifier on NNs

def get_state_shape(env):
    return 1 if len(env.observation_space.shape) == 0 else env.observation_space.shape[0]
def get_state_size(env):
    return env.observation_space.n if len(env.observation_space.shape) == 0 else env.observation_space.shape[0]
UPS_agents = [DQNAgent(get_state_size(env), env.action_space.n) for _ in range(NUM_TRAIN_R_FUNCS)]

import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool, GCNConv, GATConv

class GraphLevelGCN(torch.nn.Module):
    def __init__(self, num_node_features):
        super(GraphLevelGCN, self).__init__()
        self.conv1 = GCNConv(num_node_features, 16)
        self.conv2 = GCNConv(16, 16)
        self.linear = torch.nn.Linear(16, 1)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # edge_weights = data.edge_attr
        
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        
        # Aggregate node features to graph-level features
        x = global_mean_pool(x, batch)
        
        # Make a binary classification prediction
        x = self.linear(x)
        return torch.sigmoid(x)

class GATGraphLevelBinary(torch.nn.Module):
    def __init__(self, num_node_features):
        super(GATGraphLevelBinary, self).__init__()
        self.conv1 = GATConv(num_node_features, 8, heads=8, dropout=0.6)
        # Increase the number of output features from the first GAT layer
        self.conv2 = GATConv(8 * 8, 16, heads=1, concat=False, dropout=0.6)
        # Additional GAT layer for richer node representations
        self.linear = torch.nn.Linear(16, 1)
        # Final linear layer to produce a graph-level output

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        
        x = global_mean_pool(x, batch)  # Aggregate node features to graph-level
        x = self.linear(x)
        return torch.sigmoid(x)  # Sigmoid activation function for binary classification

# Training loop
USS_data = [nn_to_data(agent.model) for agent in USS_agents]
URS_data = [nn_to_data(agent.model) for agent in URS_agents]
print(URS_data[0].x.shape)
UPS_data = [nn_to_data(agent.model) for agent in UPS_agents]
assert URS_data[0].x.shape == UPS_data[0].x.shape

torch.Size([134, 4])


In [163]:
# Binary classification between two datasets
dataset1 = USS_data
dataset2 = URS_data
def generate_data(dataset1, dataset2):
    indices = np.random.permutation(len(dataset1) + len(dataset2))
    data = [dataset1[i] if i < len(dataset1) else dataset2[i - len(dataset1)] for i in indices]
    for i in range(len(data)):
        data[i].y = 1.0 if indices[i] < len(dataset1) else 0.0 # Binary labels for each node; 1 = URS, 0 = UPS
        # Hence roughly speaking, 1 = more coherent, 0 = less coherent

    train_data_ratio = 0.8
    train_data, test_data = data[:int(train_data_ratio * len(data))], data[int(train_data_ratio * len(data)):]
    num_node_features = data[0].x.shape[1] # Number of features for each node
    return train_data, test_data, num_node_features

train_data, test_data, num_node_features = generate_data(dataset1, dataset2)
# Loss and optimizer
model = GraphLevelGCN(num_node_features)
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train_classifier(model, criterion, optimizer, train_data, test_data, epochs = 40, patience = 3, 
                     epochs_without_improvement = 0, best_loss = float('inf')):
    for epoch in range(epochs):
        avg_train_loss = 0
        for datapt in train_data:
            model.train()
            optimizer.zero_grad()

            # print(f"datapt.x shape: {datapt.x.shape}")  # Should be [num_nodes, num_node_features]
            # print(f"datapt.edge_index shape: {datapt.edge_index.shape}")  # Should be [2, num_edges]
            out = model.forward(datapt)
            # print(out.size())
            # print(torch.tensor([[datapt.y]]).size())
            loss = criterion(out, torch.tensor([[datapt.y]]))  # Adjust shape as necessary
            loss.backward()
            optimizer.step()
            avg_train_loss += loss.item()
        avg_train_loss /= len(train_data)

        avg_test_loss = 0
        for datapt in test_data:
            model.eval()
            with torch.no_grad():
                out = model.forward(datapt)
                loss = criterion(out, torch.tensor([[datapt.y]]))
                avg_test_loss += loss.item()
        avg_test_loss /= len(test_data)
        
        print(f'Epoch {epoch+1}: Average Train Loss: {avg_train_loss}, Average Test Loss: {avg_test_loss}')
        
        # Early Stopping
        if avg_test_loss < best_loss:
            best_loss = avg_test_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print(f'Early stopping at epoch {epoch+1}')
                break
        
    metrics = {'train_loss': avg_train_loss, 'test_loss': avg_test_loss}
    return metrics

train_classifier(model, criterion, optimizer, train_data, test_data, epochs = 100, patience = 5)

Epoch 1: Average Train Loss: 0.7368341466411948, Average Test Loss: 0.6454022590070962
Epoch 2: Average Train Loss: 0.6828794175758958, Average Test Loss: 0.64594763207715
Epoch 3: Average Train Loss: 0.6786422868259251, Average Test Loss: 0.644322702777572
Epoch 4: Average Train Loss: 0.6756062641274184, Average Test Loss: 0.642260631674435
Epoch 5: Average Train Loss: 0.6731930308509618, Average Test Loss: 0.6402062258508522
Epoch 6: Average Train Loss: 0.6713151082862169, Average Test Loss: 0.6383412880008109
Epoch 7: Average Train Loss: 0.6697166946483776, Average Test Loss: 0.636723642249126
Epoch 8: Average Train Loss: 0.6685101984301582, Average Test Loss: 0.6353228765859967
Epoch 9: Average Train Loss: 0.6675114025827498, Average Test Loss: 0.6340729216404725
Epoch 10: Average Train Loss: 0.6667067179689183, Average Test Loss: 0.6329930205072742
Epoch 11: Average Train Loss: 0.6660635713487864, Average Test Loss: 0.6320858108156244
Epoch 12: Average Train Loss: 0.66557614412158

{'train_loss': 0.6433332148531917, 'test_loss': 0.6186228856175149}

In [164]:
# Test GCN model on a "more powerful" NN
print(model.forward(dataset1[0]))
print(model.forward(dataset2[0]))
powerful_models = [nn_to_data(train_dqn(env_name = env_name, episodes = 5 * i).model) 
                   for i in [1, 3, 10]]
print([model.forward(data) for data in powerful_models])

tensor([[0.6210]], grad_fn=<SigmoidBackward0>)
tensor([[0.0036]], grad_fn=<SigmoidBackward0>)
[tensor([[0.1065]], grad_fn=<SigmoidBackward0>), tensor([[0.1559]], grad_fn=<SigmoidBackward0>), tensor([[0.2046]], grad_fn=<SigmoidBackward0>)]


- The classifier training process is finicky -- sometimes it overfits, sometimes it underfits -- but sometimes can reach very low loss (< 0.002)
- Even weak classifiers classify powerful models (a.k.a. agents with >15 episodes in CartPole) as having P(URS) = 1, corresponding to coherence ~ $\infty$
- P(USS) / P(URS) is still having trouble as a metric; seems extremely difficult to detect differences between USS and URS-generated policies here with current methods
    - We will need some kind of "more advanced" coherence metric to distinguish more advanced policies; TODO: implement UUS somehow
- Adding node weights to every other node to the features passed into the GCN (such that, in CartPole, the data matrix has shape (134, 134) instead of (134, 1)) makes the GCN much worse, probably because of higher dimensionality
    - Using attention in the GNN does not help, and in fact actively overfits when using (134, 1) data
- Even with sparsity = 0.999, USS is still hard to distinguish
- For simpler discrete environments, maybe a Q-table is enough to solve the problem
- Takes >= 500 episodes, small epsilon to effectively learn DQN policy


In [33]:
# Now classifying q-table agents
env_name = "Taxi-v3"
NUM_EPS_TRAIN_R = 1000
NUM_TRAIN_R_FUNCS = 50
NUM_REWARD_CALLS = 0
env = gym.make(env_name)
def deterministic_random(*args, lb = -1, ub = 1, sparsity = 0.0, continuous = False):
    """
    Create a deterministic random number generator for a given set of arguments.
    Used to generate deterministic reward functions for the coherence classifier.
    [Edit 4/3/24: adapted to continuous state space]"""
    global NUM_REWARD_CALLS
    NUM_REWARD_CALLS += 1
    unique_seed = f"{args}".encode("utf-8")
    random.seed(unique_seed)
    return random.uniform(lb, ub) if random.random() > sparsity else random.uniform(-NEAR_ZERO, NEAR_ZERO)
def get_state_shape(env):
    return 1 if len(env.observation_space.shape) == 0 else env.observation_space.shape[0]
def get_state_size(env):
    return env.observation_space.n if len(env.observation_space.shape) == 0 else env.observation_space.shape[0]

URS_r_funcs = [lambda *args: deterministic_random(args) for _ in range(NUM_TRAIN_R_FUNCS)]
URS_agents = [train_qtable(env_name = env_name, episodes=NUM_EPS_TRAIN_R, 
                           reward_function = r_func) for r_func in URS_r_funcs]
print("Halfway there!")
USS_r_funcs = [lambda *args: deterministic_random(args, sparsity=0.99) for _ in range(NUM_TRAIN_R_FUNCS)]
USS_agents = [train_qtable(env_name = env_name, episodes=NUM_EPS_TRAIN_R,
                            reward_function = r_func) for r_func in USS_r_funcs]
UPS_agents = [QTableAgent(get_state_size(env), env.action_space.n) for _ in range(NUM_TRAIN_R_FUNCS)]

# The Q-Table is already one-hot encoded, so we don't need to convert it to a Data object

Halfway there!


In [34]:
# Test ground
print([USS_r_funcs[0](i) for i in range(10)])
test_USS_agent = train_qtable(env_name = env_name, episodes = 50, verbose=True, epsilon_decay = 100, 
                              lr = 0.01, gamma = 0.9, reward_function = USS_r_funcs[0])
test_qtable(gym.make(env_name), test_USS_agent, episodes = 100, reward_function = USS_r_funcs[0])
test_UPS_agent = QTableAgent(get_state_size(env), env.action_space.n)
test_qtable(gym.make(env_name), test_UPS_agent, episodes = 100, reward_function = USS_r_funcs[0])

[6.879610979477867e-10, -8.537837971656553e-10, 1.3453720557822044e-10, -6.91582481382308e-10, 5.834328244211258e-10, 2.689469001408139e-10, -7.747733415111116e-10, 3.5242825351422126e-10, 2.6150563760737437e-10, -8.919200340033589e-10]
Episode: 50, Average total reward: 0.0053425241032739375, Epsilon: 0.62
Maximum reward: 8
Average reward: 0.0018690184912870031
Maximum reward: 8
Average reward: 2.696479094863794


In [35]:
class FCNNBinary(nn.Module):
    def __init__(self, num_node_features):
        super(FCNNBinary, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(num_node_features, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

    def forward(self, data):
        x = data.x
        x = self.fc(x)
        return torch.sigmoid(x)
    
def qtable_to_feat(qtable: torch.Tensor, label):
    # In qtable, rows are states and columns are actions taken in that state
    return Data(x = torch.flatten(qtable), y = label) # Naive approach

In [36]:
UPS_agents = [QTableAgent(get_state_size(env), env.action_space.n) for _ in range(NUM_TRAIN_R_FUNCS)]

from torch_geometric.data import Data
for agent in UPS_agents:
    for row in agent.q_table:
        for i in range(len(row)):
            row[i] = np.random.uniform(-1, 1) # set each value to a random number between -1 and 1
dataset1 = [qtable_to_feat(torch.tensor(agent.q_table, dtype=torch.float32), 1) for agent in USS_agents]
dataset2 = [qtable_to_feat(torch.tensor(agent.q_table, dtype=torch.float32), 0) for agent in URS_agents] # URS = 1, UPS = 0

def generate_fcnn_data(dataset1, dataset2):
    indices = np.random.permutation(len(dataset1) + len(dataset2))
    data = [dataset1[i] if i < len(dataset1) else dataset2[i - len(dataset1)] for i in indices]
    for i in range(len(data)):
        data[i].y = 1.0 if indices[i] < len(dataset1) else 0.0 # Binary labels for each node; 1 = URS, 0 = UPS
        # Hence roughly speaking, 1 = more coherent, 0 = less coherent

    train_data_ratio = 0.8
    train_data, test_data = data[:int(train_data_ratio * len(data))], data[int(train_data_ratio * len(data)):]
    num_node_features = data[0].x.shape[0] # Number of features for each node
    return train_data, test_data, num_node_features

def train_fcnn_classifier(model, criterion, optimizer, train_data, test_data, epochs = 40, patience = 3, 
                          epochs_without_improvement = 0, best_loss = float('inf')):
    for epoch in range(epochs):
        avg_train_loss = 0
        for datapt in train_data:
            model.train()
            optimizer.zero_grad()

            out = model.forward(datapt)
            assert isinstance(out, torch.Tensor), f"Expected model.forward to return a tensor, but got {out}"
            loss = criterion(out, torch.tensor([datapt.y]))  # Adjust shape as necessary
            loss.backward()
            optimizer.step()
            avg_train_loss += loss.item()
        avg_train_loss /= len(train_data)

        avg_test_loss = 0
        for datapt in test_data:
            model.eval()
            with torch.no_grad():
                out = model.forward(datapt)
                loss = criterion(out, torch.tensor([datapt.y]))
                avg_test_loss += loss.item()
        avg_test_loss /= len(test_data)
        
        print(f'Epoch {epoch+1}: Average Train Loss: {avg_train_loss}, Average Test Loss: {avg_test_loss}')
        
        # Early Stopping
        if avg_test_loss < best_loss:
            best_loss = avg_test_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print(f'Early stopping at epoch {epoch+1}')
                break

train_data, test_data, num_node_features = generate_fcnn_data(dataset1, dataset2)
print(num_node_features)
model = FCNNBinary(num_node_features)
criterion = torch.nn.BCELoss()  
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
train_fcnn_classifier(model, criterion, optimizer, train_data, test_data)

3000
Epoch 1: Average Train Loss: 0.049766413247157644, Average Test Loss: 2.836417075364506e-29
Epoch 2: Average Train Loss: 3.910217640168412e-06, Average Test Loss: 3.8434927429074585e-29
Epoch 3: Average Train Loss: 3.613662419610364e-06, Average Test Loss: 5.327220408929323e-29
Epoch 4: Average Train Loss: 3.296245868611436e-06, Average Test Loss: 7.391705247613961e-29
Early stopping at epoch 4


In [37]:
def generate_UQS_qagent(rand_qtable, gamma, env: gym.Env, episodes = 500):
    """
    Train a Q-table agent based on a reward function uniformly sampled from the set of 
    possible reward functions compatible with the given random Q-table."""
    # Generate the reward function using the Bellman equation
    r_table = np.zeros(rand_qtable.shape)
    for s in range(rand_qtable.shape[0]):
        for a in range(rand_qtable.shape[1]):
            env.reset()
            env.unwrapped.s = s
            ns = env.step(a)[0]
            r_table[s, a] = rand_qtable[s, a] - gamma * np.max(rand_qtable[ns]) #assuming greedy policy
    
    # Train the agent
    r_func = lambda s, a, *args: r_table[s, a]
    return train_qtable(env_name = env.spec.id, episodes = episodes, reward_function = r_func)

UQS_agents = [generate_UQS_qagent(agent.q_table, 0.9, env, episodes = NUM_EPS_TRAIN_R) for agent in UPS_agents]

ValueError: setting an array element with a sequence.

In [38]:
def generate_UVS_qagent(rand_values, gamma, env: gym.Env, episodes = 500, lb = -1, ub = 1):
    """
    Train a Q-table agent based on a reward function uniformly sampled from the set of 
    possible reward functions compatible with the given values for each state.
    Assumes a uniform distribution between [lb, ub]."""
    r_table = np.zeros((len(rand_values), env.action_space.n))
    for s in range(len(rand_values)):
        next_states = np.zeros(env.action_space.n)
        for a in range(env.action_space.n):
            env.reset()
            env.unwrapped.s = s
            next_states[a] = env.step(a)[0]
        #v(s) = max_a(R(s, a) + gamma * v(s'))
        reward_ub = rand_values[s] - np.array([gamma * rand_values[int(ns)] for ns in next_states])
        taut_probs = np.zeros(env.action_space.n)
        for i in range(env.action_space.n):
            all_except_i = np.delete(np.arange(env.action_space.n), i)
            taut_probs[i] = np.prod((reward_ub[all_except_i] + 1) / 2)
            # probability that all other rewards at action j are less than reward_ub[j]
        taut_probs /= np.sum(taut_probs)
        taut = np.random.choice(env.action_space.n, p = taut_probs) 
        #index of the action where the reward is equal to the maximum

        rewards = np.full(env.action_space.n, float('inf'))
        while np.any(rewards >= reward_ub): #while any of the rewards are greater than the upper bound
            rewards = np.random.uniform(-1, 1, env.action_space.n)
        rewards[taut] = reward_ub[taut]
        r_table[s] = rewards
    
    r_func = lambda s, a, *args: r_table[s, a]
    return train_qtable(env_name = env.spec.id, episodes = episodes, reward_function = r_func)

# UVS_agents = [generate_UVS_qagent(np.random.uniform(-1, 1, env.unwrapped.s), 0.9, env, episodes = NUM_EPS_TRAIN_R) for _ in range(NUM_TRAIN_R_FUNCS)]
# this currently takes way too long so it has been commented out

In [39]:
UQS_agents[0].q_table - UPS_agents[0].q_table

NameError: name 'UQS_agents' is not defined

In [40]:
### Only attaching reward to terminal states (kind of like UUS? but with the inductive biases)

def det_rand_terminal(done: bool, *args, lb = -1, ub = 1, sparsity = 0.0):
    """
    Create a deterministic random number generator for a given set of arguments.
    Used to generate deterministic reward functions for the coherence classifier. """
    global NUM_REWARD_CALLS
    NUM_REWARD_CALLS += 1
    if not done:
        return random.uniform(-NEAR_ZERO, NEAR_ZERO)
    unique_seed = f"{args}".encode("utf-8")
    random.seed(unique_seed)
    return random.uniform(lb, ub) if random.random() > sparsity else random.uniform(-NEAR_ZERO, NEAR_ZERO)

UUS_agents = [train_qtable(env_name = env_name, episodes = NUM_EPS_TRAIN_R, 
                           reward_function = lambda *args: det_rand_terminal(*args)) for _ in range(NUM_TRAIN_R_FUNCS)]

In [41]:
### Turn the state and action space of Taxi-v3 into a graph

from collections import defaultdict
taxi_env = gym.make("Taxi-v3")
taxi_env.reset()
# Initialize containers for graph data
edges = defaultdict(list)
edge_attr = defaultdict(list)

# A helper function to encode the state into a single number (node index)
def state_to_node(taxi_row, taxi_col, pass_loc, dest_idx):
    # This encoding assumes specific knowledge about the Taxi-v3 state space size
    return taxi_row * 100 + taxi_col * 20 + pass_loc * 4 + dest_idx
    # max = 4 * 100 + 4 * 20 + 4 * 4 + 3 = 400 + 80 + 16 + 3 = 499

# Iterate through all possible states and actions to construct the graph
for taxi_row in range(5):
    for taxi_col in range(5):
        for pass_loc in range(5):  # 4 locations + 1 for 'in taxi'
            for dest_idx in range(4):
                current_state = state_to_node(taxi_row, taxi_col, pass_loc, dest_idx)
                for action in range(taxi_env.action_space.n):
                    # Set the environment to the current state
                    taxi_env.unwrapped.s = current_state
                    # Take action and observe the next state and reward
                    next_state, reward, done, _ = taxi_env.step(action)
                    # Add edge from current state to next state
                    edges[current_state].append(next_state)
                    # Optionally, use rewards as edge attributes
                    # edge_attr[(current_state, next_state)].append(reward)
                    taxi_env.reset()


# Convert edges and edge attributes to tensors
edge_index = []
for src, dsts in edges.items():
    for dst in dsts:
        edge_index.append([src, dst])
edge_index = torch.tensor(edge_index).t().contiguous()
edge_index, edge_index.shape

(tensor([[  0,   0,   0,  ..., 499, 499, 499],
         [100,   0,  20,  ..., 479, 499, 499]]),
 torch.Size([2, 3000]))

In [105]:
def greedy_policy(q_table):
    return torch.tensor(np.argmax(q_table, axis=1).reshape(-1, 1).astype(np.float32))
def random_policy(state_dim):
    return torch.randint(0, 6, (state_dim, 1)).float()
def prep_qtable(q_table):
    return torch.tensor(q_table, dtype=torch.float32)

dataset1 = [Data(x = greedy_policy(agent.q_table), edge_index = edge_index, y = 1) for agent in UUS_agents]
dataset2 = [Data(x = greedy_policy(agent.q_table), edge_index = edge_index, y = 0) for agent in URS_agents]
# dataset2 = [Data(x = random_policy(agent.q_table.shape[0]), edge_index = edge_index, y = 0) for agent in UPS_agents]
# ^ random_policy = UPS sampling
print(dataset1[0].x.shape)

train_data, test_data, num_node_features = generate_data(dataset1, dataset2)
model = GraphLevelGCN(num_node_features)
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
train_classifier(model, criterion, optimizer, train_data, test_data, epochs = 80, patience = 5)

torch.Size([500, 1])
Epoch 1: Average Train Loss: 0.6975951328873634, Average Test Loss: 0.6983259975910187
Epoch 2: Average Train Loss: 0.6948948062956333, Average Test Loss: 0.6966440558433533
Epoch 3: Average Train Loss: 0.6944274820387364, Average Test Loss: 0.6957786858081818
Epoch 4: Average Train Loss: 0.6941709764301777, Average Test Loss: 0.695288798213005
Epoch 5: Average Train Loss: 0.6939983785152435, Average Test Loss: 0.6949822187423706
Epoch 6: Average Train Loss: 0.693861399590969, Average Test Loss: 0.6947700291872024
Epoch 7: Average Train Loss: 0.6937379993498325, Average Test Loss: 0.6945965200662613
Epoch 8: Average Train Loss: 0.6936026096343995, Average Test Loss: 0.6944513291120529
Epoch 9: Average Train Loss: 0.6934472382068634, Average Test Loss: 0.6942998260259629
Epoch 10: Average Train Loss: 0.6932599917054176, Average Test Loss: 0.6941128522157669
Epoch 11: Average Train Loss: 0.6930280290544033, Average Test Loss: 0.6938772439956665
Epoch 12: Average Trai

{'train_loss': 0.07837411074433476, 'test_loss': 0.04979511229321361}

In [None]:
import wandb
def train(project_name: str, model, criterion, train_data, test_data):
    wandb.init(project=project_name)
    sweep_config = {
        "method": "random",
        "parameters": {
            "lr": {
                "distribution": "log_uniform_values", 
                "min": 1e-4,
                "max": 1e-1
            },
            "weight_decay": {
                "values": [5e-4, 1e-4]
            },
            "epochs": {"values": [10, 20, 40, 80]}
        },
        "metric": {"goal": "minimize", "name": "test_loss"}
    }
    optimizer = torch.optim.Adam(model.parameters(), lr=wandb.config["lr"], weight_decay=wandb.config["weight-decay"])
    metrics = train_classifier(model, criterion, optimizer, train_data, test_data)
    wandb.log(metrics)
wandb.log({})

In [106]:
print(torch.transpose(dataset1[0].x, 0, -1)[0:10, 0:10]) # Example greedy policies
print(torch.transpose(dataset2[0].x, 0, -1)[0:10, 0:10])
print(model.forward(dataset1[0]))
print(model.forward(dataset2[0]))

tensor([[0., 0., 0., 2., 3., 0., 0., 0., 2., 1.]])
tensor([[0., 0., 3., 0., 0., 0., 0., 0., 0., 3.]])
tensor([[0.9858]], grad_fn=<SigmoidBackward0>)
tensor([[0.0830]], grad_fn=<SigmoidBackward0>)


In [146]:
taxi_model = train_qtable(env_name = "Taxi-v3", episodes = 20000, verbose = True, print_every = 2000, 
                          return_reward = False)
taxi_data = Data(x = torch.tensor(greedy_policy(taxi_model.q_table)), edge_index = edge_index)

Episode: 2000, Average total reward: -298.0345172586293, Epsilon: 0.03
Episode: 4000, Average total reward: -98.47173586793397, Epsilon: 0.01
Episode: 6000, Average total reward: -49.28064032016008, Epsilon: 0.01
Episode: 8000, Average total reward: -19.100550275137568, Epsilon: 0.01
Episode: 10000, Average total reward: -3.5672836418209104, Epsilon: 0.01
Episode: 12000, Average total reward: 2.729864932466233, Epsilon: 0.01
Episode: 14000, Average total reward: 5.246623311655828, Epsilon: 0.01
Episode: 16000, Average total reward: 6.6213106553276635, Epsilon: 0.01
Episode: 18000, Average total reward: 7.011505752876438, Epsilon: 0.01
Episode: 20000, Average total reward: 7.078039019509755, Epsilon: 0.01


  taxi_data = Data(x = torch.tensor(greedy_policy(taxi_model.q_table)), edge_index = edge_index)


In [147]:
print(taxi_data.x.T)
print(model.forward(taxi_data))

tensor([[0., 4., 4., 4., 3., 0., 2., 0., 0., 0., 0., 0., 1., 0., 0., 0., 5., 0.,
         0., 0., 0., 3., 3., 3., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         3., 0., 3., 0., 0., 0., 0., 0., 2., 0., 2., 2., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 2., 0., 1., 0., 0., 0., 0., 2., 0., 2., 2., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 2., 0., 0., 0., 2., 2., 3., 4., 0., 4., 4., 1., 2.,
         0., 4., 0., 0., 0., 0., 3., 5., 3., 0., 0., 1., 1., 1., 2., 0., 0., 0.,
         0., 0., 0., 0., 2., 0., 0., 0., 1., 0., 0., 0., 0., 3., 3., 3., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 3., 0., 3., 0., 0., 0., 0., 0.,
         2., 0., 2., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 2., 0., 0., 0., 0.,
         0., 0., 2., 0., 2., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 2., 0., 0.,
         0., 3., 3., 3., 1., 0., 1., 1., 3., 3., 0., 3., 3., 3., 3., 0., 3., 1.,
         3., 3., 0., 1., 1., 1., 2., 0., 2., 2., 0., 0., 0., 0., 2., 2., 2., 0.,
         1., 2., 0., 2., 0.,

In [148]:
test_qtable(gym.make("Taxi-v3"), taxi_model, episodes = 1000)
taxi_model.q_table

Maximum reward: 8


Average reward: -9.875


array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ],
       [-3.45631803, -3.40184505, -3.44935417, -3.30738336,  9.61017345,
        -4.00532755],
       [-1.32423015, -1.4053787 , -1.69223727, -1.42928145, 14.1179476 ,
        -3.41292327],
       ...,
       [-0.79663486, -0.68870221, -0.79279359, -0.79716306, -2.42284272,
        -2.02425586],
       [-2.18994752, -2.18745726, -2.19043425, -1.79490775, -2.2865236 ,
        -2.63298791],
       [-0.04946138, -0.05038011, -0.06737791,  2.01012662, -0.58588332,
        -0.58404418]])

In [151]:
class Node:
    def __init__(self, state, parent=None, action=None, q_values=None):
        self.state = state
        self.parent = parent
        self.action = action  # Action taken to reach this node
        self.children = []
        self.visits = 1  # Initialize to avoid division by zero
        self.value = 0
        self.q_values = q_values  # This should be a dictionary or similar structure

    def add_child(self, child):
        self.children.append(child)

    def update(self, reward):
        self.visits += 1
        self.value += reward

    def is_fully_expanded(self, env):
        return len(self.children) == env.action_space.n

    def best_child(self, c_param=1.4):
        choices_weights = [
            (child.value / child.visits) + c_param * np.sqrt((2 * np.log(self.visits) / child.visits))
            for child in self.children
        ]
        return self.children[np.argmax(choices_weights)]

def rollout_policy(state, q_table, env):
    # Use the Q-table to select the best action if this state has been seen
    if state in q_table:
        return np.argmax(q_table[state])
    else:
        # Otherwise, select a random action
        return env.action_space.sample()

def selection(node, env):
    while not node.is_fully_expanded(env):
        if not node.children:
            return expansion(node, env)
        else:
            node = node.best_child()
    return node

def expansion(node, env):
    tried_actions = [child.action for child in node.children]
    for action in range(env.action_space.n):
        if action not in tried_actions:
            env.env.s = node.state  # Set environment to current node's state
            next_state, _, _, _ = env.step(action)
            new_node = Node(next_state, parent=node, action=action, q_values=node.q_values)
            node.add_child(new_node)
            return new_node
    return node  # In case all actions were tried

def simulation(node, env, max_steps=100):
    total_reward = 0
    current_state = node.state
    steps = 0

    while steps < max_steps:
        action = rollout_policy(current_state, node.q_values, env)
        env.env.s = current_state
        next_state, reward, done, _ = env.step(action)
        total_reward += reward
        current_state = next_state
        steps += 1
        if done:
            break

    return total_reward

def backpropagation(node, reward):
    while node is not None:
        node.update(reward)
        node = node.parent

def mcts(root, env, iterations=1000):
    for _ in range(iterations):
        leaf = selection(root, env)
        reward = simulation(leaf, env)
        backpropagation(leaf, reward)

# Example usage
env_name = "Taxi-v3"
env = gym.make(env_name)
initial_state = env.reset()

# Assume taxi_model.q_table is your pre-trained Q-table
# It should be a dictionary where keys are states and values are arrays of Q-values for each action
q_table = taxi_model.q_table  # Replace this with your actual Q-table

root_node = Node(initial_state, q_values=q_table)
mcts(root_node, env, iterations=1000)

In [152]:
### Test MCTS

def choose_action(node):
    # Choose the child with the highest visit count
    if node.children:
        return max(node.children, key=lambda child: child.visits).action
    else:
        return None

def simulate_episode_from_root(env, root_node):
    total_reward = 0
    done = False
    current_node = root_node
    env.reset()
    env.env.s = current_node.state
    
    while not done and current_node is not None:
        action = choose_action(current_node)
        if action is None:
            # No more information in the tree; choose random action
            action = env.action_space.sample()
        
        next_state, reward, done, _ = env.step(action)  # Execute the chosen action
        total_reward += reward
        
        # Move to the next node in the tree, if it exists
        next_node = None
        for child in current_node.children:
            if child.action == action:
                next_node = child
                break
        current_node = next_node

    return total_reward


# Test the policy derived from the MCTS root node
env = gym.make('Taxi-v3')
average_reward = np.mean([simulate_episode_from_root(env, root_node) for _ in range(100)])
print(f"Average Reward from the MCTS policy: {average_reward}")
test_qtable(env, taxi_model, episodes = 100)


Average Reward from the MCTS policy: -200.0
Maximum reward: 8
Average reward: -8.52


- GNNs over the environment work really well, even on USS/URS and UUS/URS (identifying "sparsity")
    - Real Taxi agents are graded strongly towards USS and UUS end
    - Looking at the q-tables, there are some noticible differences (e.g. USS q-tables tend to have lower magnitude)
    - GNNs don't work yet on USS/URS when only the policy is passed in
    - Started working once I put the fix in of changing the reward function of the environment *for each state*
- Luckily the classification of a good Taxi agent (3000 eps) under USS/URS is not too high (p = 0.9865)?

In [153]:
# Generate tabular policy from MCTS and feed through classifier

def extract_policy(root_node, env):
    policy = np.random.randint(0, env.action_space.n, env.observation_space.n)
    # default action is random in case the state is not in the tree
    node_queue = [root_node]
    
    num_not_random = 0
    while node_queue:
        num_not_random += 1
        current_node = node_queue.pop(0)
        if current_node.is_fully_expanded(env):
            best_action = current_node.best_child().action
            policy[current_node.state] = best_action
            node_queue.extend(current_node.children)
        else:
            # If the node isn't fully expanded, we take the best action tried so far
            # This is rare in fully run MCTS but can happen if the tree isn't deep enough
            if current_node.children:
                best_action = max(current_node.children, key=lambda x: x.visits).action
                policy[current_node.state] = best_action
                node_queue.extend(current_node.children)

    return policy, num_not_random  

mcts_policy, num_not_random = extract_policy(root_node, env)
model.forward(Data(x = torch.tensor(mcts_policy.reshape(-1, 1).astype(np.float32)), edge_index = edge_index))

tensor([[1.0000]], grad_fn=<SigmoidBackward0>)

In [155]:
### A la Wentworth's definition of coherence, we create policies that do and do not "contradict"
# themselves, i.e. there exists a value function consistent with the policy, and pass them
# through the classifier

coherent_policy = greedy_policy(taxi_model.q_table).detach()
incoherent_policy = greedy_policy(taxi_model.q_table).detach()
for _ in range(100):
    env.reset()
    i = env.unwrapped.s # +100 for moving one row, + 20 for moving one column
    env.step(coherent_policy[i].item())
    j = env.unwrapped.s
    if coherent_policy[i][0] % 2 == 0:
        incoherent_policy[j][0] = coherent_policy[i][0] + 1
    else:
        incoherent_policy[j][0] = coherent_policy[i][0] - 1 # if 0, then 1; if 1, then 0
    # point is to put incoherent_policy in a loop

print((coherent_policy != incoherent_policy).nonzero().shape[0])
print(model.forward(Data(x = coherent_policy.detach(), edge_index = edge_index)))
print(model.forward(Data(x = incoherent_policy.detach(), edge_index = edge_index)))

75
tensor([[0.6893]], grad_fn=<SigmoidBackward0>)
tensor([[0.6078]], grad_fn=<SigmoidBackward0>)


In [156]:
class PolicyAgent:
    def __init__(self, policy, epsilon = 0.1):
        self.policy, self.epsilon = policy, epsilon
    def act(self, state, epsilon):
        if random.random() > epsilon:
            action = self.policy[state]
        else:
            action = random.randrange(self.action_dim)
        return action
    
c_agent, ic_agent = PolicyAgent(np.array(coherent_policy.T[0])), PolicyAgent(np.array(incoherent_policy.T[0]))
test_qtable(env, c_agent, episodes = 1000)
test_qtable(env, ic_agent, episodes = 1000)

Maximum reward: 8
Average reward: -6.553
Maximum reward: 8
Average reward: -182.233
