In [1]:
import gym
import numpy as np
import random
from collections import deque
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch_geometric.nn.conv import GCNConv
import math

Setup:
- Starting point: just try to train classifier on RL policies

In [2]:
### DQN implementation

# Define the Q-network
class DQN(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super(DQN, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, x):
        return self.fc(x)
    
    def get_weights(self):
        return self.state_dict()

# Experience Replay Buffer
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        state = np.expand_dims(state, 0)
        next_state = np.expand_dims(next_state, 0)
            
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        return np.concatenate(state), action, reward, np.concatenate(next_state), done

    def __len__(self):
        return len(self.buffer)

# DQN Agent
class DQNAgent:
    def __init__(self, state_dim, action_dim, hidden_dim=64, lr=1e-3, batch_size=64, gamma=0.99, replay_size=1000):
        self.model = DQN(state_dim, action_dim, hidden_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.replay_buffer = ReplayBuffer(replay_size)
        self.batch_size = batch_size
        self.gamma = gamma
        self.action_dim = action_dim
    
    def update(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        state, action, reward, next_state, done = self.replay_buffer.sample(self.batch_size)
        
        state = torch.FloatTensor(state)
        next_state = torch.FloatTensor(next_state)
        action = torch.LongTensor(action)
        reward = torch.FloatTensor(reward)
        done = torch.FloatTensor(done)

        q_values = self.model(state)
        next_q_values = self.model(next_state)
        
        q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
        next_q_value = next_q_values.max(1)[0]
        expected_q_value = reward + self.gamma * next_q_value * (1 - done)
        
        loss = nn.MSELoss()(q_value, expected_q_value.detach())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def act(self, state, epsilon):
        if random.random() > epsilon:
            state = torch.FloatTensor(np.expand_dims(state, 0))
            q_value = self.model(state)
            action = q_value.max(1)[1].item()
        else:
            action = random.randrange(self.action_dim)
        return action

In [3]:
NUM_NON_ZERO_REWARDS = 0
def train_dqn(env_name="CartPole-v1", episodes=500, epsilon_start=1.0, epsilon_final=0.01, 
              epsilon_decay=500, reward_function = None):
    global NUM_NON_ZERO_REWARDS
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    agent = DQNAgent(state_dim, action_dim)
    
    epsilon_by_frame = lambda frame_idx: epsilon_final + (epsilon_start - epsilon_final) * np.exp(-1. * frame_idx / epsilon_decay)
    
    for episode in range(episodes):
        state = env.reset()
        episode_reward = 0
        while True:
            epsilon = epsilon_by_frame(episode)
            action = agent.act(state, epsilon)
            next_state, reward, done, _ = env.step(action)
            if reward_function:
                reward = reward_function(state, action, next_state, done)
            NUM_NON_ZERO_REWARDS += 0 if math.isclose(reward, 0) else 1
            agent.replay_buffer.push(state, action, reward, next_state, done)
            state = next_state
            episode_reward += reward
            
            agent.update()
            
            if done:
                break
            # print(f"Episode: {episode+1}, Total reward: {episode_reward}, Epsilon: {epsilon:.2f}")

        # Optional: Render the environment to visualize training progress
        # if episode % 100 == 0:
        #     render_env(env, agent)

    env.close()
    return agent

# Optional: Function to render the environment with the current policy
def render_env(env, agent):
    state = env.reset()
    done = False
    while not done:
        action = agent.act(state, 0)  # Using 0 epsilon for greedy action selection
        # print(env.step(action))
        next_state, reward, done, _ = env.step(action)
        env.render()
        state = next_state

In [4]:
def test_dqn(env, agent, episodes=10):
    print(f"Maximum reward: {env.spec.reward_threshold}")
    for episode in range(episodes):
        # if episode == 0:
        #     render_env(env, agent)
        state = env.reset()
        episode_reward = 0
        done = False
        while not done:
            action = agent.act(state, 0)  # Using 0 epsilon for greedy action selection
            next_state, reward, done, _ = env.step(action)
            episode_reward += reward
            state = next_state
        print(f"Episode: {episode+1}, Total reward: {episode_reward}")

In [5]:
env_name = "CartPole-v1"
agent = train_dqn(env_name = env_name, episodes = 500)

In [6]:
test_dqn(gym.make("CartPole-v1"), agent)
agent.model.get_weights()

Maximum reward: 475.0
Episode: 1, Total reward: 172.0
Episode: 2, Total reward: 166.0
Episode: 3, Total reward: 228.0
Episode: 4, Total reward: 157.0
Episode: 5, Total reward: 208.0
Episode: 6, Total reward: 298.0
Episode: 7, Total reward: 500.0
Episode: 8, Total reward: 280.0
Episode: 9, Total reward: 500.0
Episode: 10, Total reward: 500.0


OrderedDict([('fc.0.weight',
              tensor([[ 2.2510e+00, -3.3805e-01, -3.9947e+00, -7.4218e-01],
                      [ 5.8213e-01, -5.0631e-01, -6.0192e+00, -1.4179e+00],
                      [-6.2113e-02,  9.3656e-01, -4.7122e-01, -9.8996e-01],
                      [ 7.4555e-01, -6.4955e-01, -1.5013e+00, -1.8160e-01],
                      [-2.5369e-01,  7.5536e-01,  5.6088e+00,  1.0261e+00],
                      [-5.4400e-01,  7.0318e-02, -6.0223e-01, -5.3441e-01],
                      [ 1.6757e-01,  1.7796e-01, -8.1427e-01,  7.7093e-02],
                      [ 3.9939e-01, -4.9977e-01, -6.1514e+00, -1.2189e+00],
                      [ 1.4535e-01, -1.1467e-01, -9.4931e-01, -5.1248e-01],
                      [-3.6525e-01, -5.6429e-01, -5.0702e-01, -4.6492e-01],
                      [-3.2165e-01,  3.9664e-01, -4.8892e-02, -6.2249e-02],
                      [-4.0101e-01,  4.6125e-02, -7.8735e-01, -5.1730e-01],
                      [ 5.0324e-01, -6.5878e-01, -4.7073e+0

In [7]:
### Coherence classifier

#agent.model.get_weights()

# Define a simple GCN model
from torch_geometric.data import Data
class GCN(torch.nn.Module):
    def __init__(self, data):
        super(GCN, self).__init__()
        # Define the GCN layers
        self.conv1 = GCNConv(data.num_node_features, 4)  # Input features to hidden
        self.conv2 = GCNConv(4, 2)  # Hidden to output features
        self.data = data

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        
        # Pass data through the first GCN layer, then apply ReLU
        x = torch.relu(self.conv1(x, edge_index))
        # Pass data through the second GCN layer
        x = self.conv2(x, edge_index)
        return x
    
    
def nn_to_data(model: nn.Module) -> Data:
    edges = []
    node_features = []

    # Counter for global neuron index
    global_neuron_index = 0

    # Iterate over each layer in the network
    base = next(model.children())
    if isinstance(base, nn.Sequential):
        layers = base.children()
    else:
        layers = model.children()

    if isinstance(base, nn.Sequential):
        input_dim = base[0].in_features
        node_features = np.zeros(input_dim)

    for layer in layers:
        if isinstance(layer, nn.Linear):
            # Update edges based on the weight matrix
            for i in range(layer.weight.shape[1]):  # Input neurons
                for j in range(layer.weight.shape[0]):  # Output neurons
                    edges.append((global_neuron_index + i, global_neuron_index + layer.weight.shape[1] + j))
            
            # Update node features (e.g., biases)
            extension = layer.bias.detach().numpy()
            node_features = np.append(node_features, extension)
            
            # Update the global neuron index
            global_neuron_index += layer.weight.shape[1]

    # Convert lists to PyTorch tensors
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    x = torch.tensor(node_features, dtype=torch.float).view(-1, 1)
    return Data(x=x, edge_index=edge_index)

data = nn_to_data(agent.model)
gcn = GCN(data)
# data.x.shape, data.edge_index.shape
# print(data.x)

#Debug
out_of_bounds = data.edge_index >= data.x.shape[0]
if out_of_bounds.any():
    print("Out-of-bounds indices found at locations:")
    print(data.edge_index[:, out_of_bounds.any(dim=0)])

In [8]:
# Dataset generation
env = gym.make(env_name)
NEAR_ZERO = 1e-9
NUM_REWARD_CALLS = 0
NUM_NON_ZERO_REWARDS = 0
def deterministic_random(*args, lb = -1, ub = 1, sparsity = 0.0):
    global NUM_REWARD_CALLS
    NUM_REWARD_CALLS += 1
    unique_seed = f"{args}".encode("utf-8")
    random.seed(unique_seed)
    return random.uniform(lb, ub) if random.random() > sparsity else random.uniform(-NEAR_ZERO, NEAR_ZERO)

NUM_TRAIN_R_FUNCS = 50
NUM_EPS_TRAIN_R = 50
URS_r_funcs = [lambda s, a, ns, d: deterministic_random(s, a, ns, d) for _ in range(NUM_TRAIN_R_FUNCS)]
URS_agents = [train_dqn(env_name = env_name, 
                        episodes=NUM_EPS_TRAIN_R, reward_function=r_func) for r_func in URS_r_funcs]
USS_r_funcs = [lambda s, a, ns, d: deterministic_random(s, a, ns, d, sparsity=0.99) for _ in range(NUM_TRAIN_R_FUNCS)]
print(f"Number of reward function calls: {NUM_REWARD_CALLS}")
print(f"Number of non-zero rewards: {NUM_NON_ZERO_REWARDS}")
USS_agents = [train_dqn(env_name = env_name, 
                        episodes=NUM_EPS_TRAIN_R, reward_function=r_func) for r_func in USS_r_funcs]

Number of reward function calls: 58866
Number of non-zero rewards: 58866


In [13]:
UPS_agents = [DQNAgent(env.observation_space.shape[0], env.action_space.n) for _ in range(NUM_TRAIN_R_FUNCS)]

import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool, GCNConv

class GraphLevelGCN(torch.nn.Module):
    def __init__(self, num_node_features):
        super(GraphLevelGCN, self).__init__()
        self.conv1 = GCNConv(num_node_features, 16)
        self.conv2 = GCNConv(16, 16)
        self.linear = torch.nn.Linear(16, 1)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        
        # Aggregate node features to graph-level features
        x = global_mean_pool(x, batch)
        
        # Make a binary classification prediction
        x = self.linear(x)
        return torch.sigmoid(x)

# Training loop
USS_data = [nn_to_data(agent.model) for agent in USS_agents]
URS_data = [nn_to_data(agent.model) for agent in URS_agents]
print(URS_data[0].x.shape)
UPS_data = [nn_to_data(agent.model) for agent in UPS_agents]
assert URS_data[0].x.shape == UPS_data[0].x.shape

# Binary classification between two datasets
dataset1 = URS_data
dataset2 = UPS_data
indices = np.random.permutation(len(dataset1) + len(dataset2))
data = [dataset1[i] if i < len(dataset1) else dataset2[i - len(dataset1)] for i in indices]
for i in range(len(data)):
    data[i].y = 1.0 if indices[i] < len(dataset1) else 0.0 # Binary labels for each node; 1 = URS, 0 = UPS
    # Hence roughly speaking, 1 = more coherent, 0 = less coherent

train_data_ratio = 0.8
train_data, test_data = data[:int(train_data_ratio * len(data))], data[int(train_data_ratio * len(data)):]
# Loss and optimizer
num_node_features = 1 # just the bias term
model = GraphLevelGCN(num_node_features)
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

epochs = 40
# Set the number of epochs to wait for early stopping
patience = 3
# Initialize variables for early stopping
best_loss = float('inf')
epochs_without_improvement = 0

for epoch in range(epochs):
    avg_train_loss = 0
    for datapt in train_data:
        model.train()
        optimizer.zero_grad()

        # print(f"datapt.x shape: {datapt.x.shape}")  # Should be [num_nodes, num_node_features]
        # print(f"datapt.edge_index shape: {datapt.edge_index.shape}")  # Should be [2, num_edges]
        out = model.forward(datapt)
        # print(out.size())
        # print(torch.tensor([[datapt.y]]).size())
        loss = criterion(out, torch.tensor([[datapt.y]]))  # Adjust shape as necessary
        loss.backward()
        optimizer.step()
        avg_train_loss += loss.item()
    avg_train_loss /= len(train_data)

    avg_test_loss = 0
    for datapt in test_data:
        model.eval()
        with torch.no_grad():
            out = model.forward(datapt)
            loss = criterion(out, torch.tensor([[datapt.y]]))
            avg_test_loss += loss.item()
    avg_test_loss /= len(test_data)
    
    print(f'Epoch {epoch+1}: Average Train Loss: {avg_train_loss}, Average Test Loss: {avg_test_loss}')
    
    # Early Stopping
    if avg_test_loss < best_loss:
        best_loss = avg_test_loss
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print(f'Early stopping at epoch {epoch+1}')
            break

torch.Size([134, 1])
Epoch 1: Average Train Loss: 0.6972218137234449, Average Test Loss: 0.7209133118391037
Epoch 2: Average Train Loss: 0.6848836325109005, Average Test Loss: 0.6991690650582314
Epoch 3: Average Train Loss: 0.6643256038427353, Average Test Loss: 0.6459048129618168
Epoch 4: Average Train Loss: 0.6055300385749434, Average Test Loss: 0.5488033395260572
Epoch 5: Average Train Loss: 0.5094857397722083, Average Test Loss: 0.40308779150946067
Epoch 6: Average Train Loss: 0.4064270003905278, Average Test Loss: 0.26802050472770134
Epoch 7: Average Train Loss: 0.3309012855803303, Average Test Loss: 0.1936153618153192
Epoch 8: Average Train Loss: 0.2804601050406063, Average Test Loss: 0.15175510363769718
Epoch 9: Average Train Loss: 0.24652311684744072, Average Test Loss: 0.12628575226699468
Epoch 10: Average Train Loss: 0.2230531989687224, Average Test Loss: 0.10852708254624303
Epoch 11: Average Train Loss: 0.20634262898853475, Average Test Loss: 0.09578468498843903
Epoch 12: Av

In [14]:
# Test GCN model on a "more powerful" NN
print(model.forward(dataset1[0]))
print(model.forward(dataset2[0]))
powerful_models = [nn_to_data(train_dqn(env_name = env_name, episodes = 5 * i).model) 
                   for i in [1, 3, 10]]
print([model.forward(data) for data in powerful_models])

tensor([[0.9984]], grad_fn=<SigmoidBackward0>)
tensor([[0.0067]], grad_fn=<SigmoidBackward0>)
[tensor([[0.2387]], grad_fn=<SigmoidBackward0>), tensor([[1.]], grad_fn=<SigmoidBackward0>), tensor([[1.]], grad_fn=<SigmoidBackward0>)]


- The classifier training process is finicky -- sometimes it overfits, sometimes it underfits -- but sometimes can reach very low loss (< 0.002)
- Even weak classifiers classify powerful models (a.k.a. agents with >15 episodes in CartPole) as having P(URS) = 1, corresponding to coherence ~ $\infty$
- P(USS) / P(URS) is still having trouble as a metric; seems extremely difficult to detect differences between USS and URS-generated policies here with current methods
    - We will need some kind of "more advanced" coherence metric to distinguish more advanced policies; TODO: implement UUS somehow