In [1]:
import gym
import numpy as np
import random
from collections import deque
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch_geometric.nn.conv import GCNConv

Setup:
- Starting point: just try to train a classifier on RL policies

In [2]:
### DQN implementation

# Define the Q-network
class DQN(nn.Module):
    """Fully connected Q-network: state vector in, one Q-value per action out.

    The layers live in an ``nn.Sequential`` stored as ``self.fc``:
    Linear -> ReLU -> Linear -> ReLU -> Linear.

    Args:
        state_dim: Number of input features.
        action_dim: Number of outputs (one per discrete action).
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        layer_stack = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        ]
        self.fc = nn.Sequential(*layer_stack)

    def forward(self, x):
        """Return the Q-values produced by the layer stack for ``x``."""
        q_values = self.fc(x)
        return q_values

    def get_weights(self):
        """Return the module's ``state_dict`` (parameter name -> tensor)."""
        return self.state_dict()

# Experience Replay Buffer
class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling."""

    def __init__(self, capacity):
        # A deque with maxlen silently discards the oldest entry when full.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Store one transition; both states get a leading axis of size 1."""
        transition = (
            np.expand_dims(state, 0),
            action,
            reward,
            np.expand_dims(next_state, 0),
            done,
        )
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions at random.

        Returns:
            ``(states, actions, rewards, next_states, dones)``. The two state
            entries are arrays concatenated along axis 0; the other three are
            tuples in the same order.
        """
        picked = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*picked)
        return np.concatenate(states), actions, rewards, np.concatenate(next_states), dones

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

# DQN Agent
class DQNAgent:
    """DQN agent bundling a Q-network, its optimizer and a replay buffer.

    The same network is used for both the current Q-values and the bootstrap
    target (there is no separate target network).

    Args:
        state_dim: Size of the observation vector.
        action_dim: Number of discrete actions.
        hidden_dim: Hidden width passed to ``DQN``.
        lr: Adam learning rate.
        batch_size: Number of transitions sampled per ``update`` call.
        gamma: Discount factor.
        replay_size: Capacity of the replay buffer.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64, lr=1e-3, batch_size=64, gamma=0.99, replay_size=1000):
        self.model = DQN(state_dim, action_dim, hidden_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.replay_buffer = ReplayBuffer(replay_size)
        self.batch_size = batch_size
        self.gamma = gamma
        self.action_dim = action_dim

    def update(self):
        """Run one gradient step on a sampled minibatch.

        Does nothing until the buffer holds at least ``batch_size`` transitions.
        """
        if len(self.replay_buffer) < self.batch_size:
            return
        state, action, reward, next_state, done = self.replay_buffer.sample(self.batch_size)

        state = torch.FloatTensor(state)
        next_state = torch.FloatTensor(next_state)
        action = torch.LongTensor(action)
        reward = torch.FloatTensor(reward)
        done = torch.FloatTensor(done)

        # Q(s, a) for the actions that were actually taken.
        q_value = self.model(state).gather(1, action.unsqueeze(1)).squeeze(1)

        # The bootstrap target is a constant for the loss, so compute it
        # without recording an autograd graph instead of detaching afterwards.
        with torch.no_grad():
            next_q_value = self.model(next_state).max(1)[0]
            # (1 - done) removes the bootstrap term for terminal transitions.
            expected_q_value = reward + self.gamma * next_q_value * (1 - done)

        loss = nn.MSELoss()(q_value, expected_q_value)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def act(self, state, epsilon):
        """Pick an action epsilon-greedily.

        With probability ``epsilon`` a uniformly random action index is
        returned; otherwise the argmax of the network's Q-values for ``state``.

        Returns:
            The chosen action as a Python int.
        """
        if random.random() > epsilon:
            state = torch.FloatTensor(np.expand_dims(state, 0))
            # Pure inference: no gradients are needed for action selection.
            with torch.no_grad():
                q_value = self.model(state)
            action = q_value.max(1)[1].item()
        else:
            action = random.randrange(self.action_dim)
        return action

In [3]:
def train_dqn(env_name="CartPole-v1", episodes=500, epsilon_start=1.0, epsilon_final=0.01,
              epsilon_decay=500, reward_function=None):
    """Train a fresh ``DQNAgent`` with epsilon-greedy exploration and return it.

    The code uses the classic Gym call shapes: ``env.reset()`` returns the
    observation and ``env.step()`` returns ``(next_state, reward, done, info)``.

    Args:
        env_name: Gym environment id passed to ``gym.make``.
        episodes: Number of training episodes.
        epsilon_start: Exploration rate at episode 0.
        epsilon_final: Value the exploration rate decays towards.
        epsilon_decay: Decay constant, measured in episodes:
            ``epsilon = final + (start - final) * exp(-episode / epsilon_decay)``.
        reward_function: Optional callable
            ``(state, action, next_state, done) -> reward``. When given, its
            result replaces the environment's reward for every transition.

    Returns:
        The trained ``DQNAgent``.
    """
    def epsilon_for_episode(episode_idx):
        # The schedule is indexed by episode number, not by environment step.
        return epsilon_final + (epsilon_start - epsilon_final) * np.exp(-1. * episode_idx / epsilon_decay)

    env = gym.make(env_name)
    try:
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.n
        agent = DQNAgent(state_dim, action_dim)

        for episode in range(episodes):
            # Epsilon only depends on the episode index, so compute it once here.
            epsilon = epsilon_for_episode(episode)
            state = env.reset()
            done = False
            while not done:
                action = agent.act(state, epsilon)
                next_state, reward, done, _ = env.step(action)
                if reward_function is not None:
                    reward = reward_function(state, action, next_state, done)
                agent.replay_buffer.push(state, action, reward, next_state, done)
                state = next_state

                # One gradient step per environment step (no-op until the
                # replay buffer holds a full batch).
                agent.update()
    finally:
        # Release the environment even if training is interrupted or raises.
        env.close()
    return agent

# Optional: Function to render the environment with the current policy
def render_env(env, agent):
    """Play one episode with the greedy policy, rendering after every step.

    Args:
        env: Environment exposing ``reset()``, ``step(action)`` returning a
            4-tuple whose third item is the done flag, and ``render()``.
        agent: Object exposing ``act(state, epsilon)``.
    """
    observation = env.reset()
    episode_over = False
    while not episode_over:
        # Epsilon 0 means the agent always takes its greedy action.
        greedy_action = agent.act(observation, 0)
        observation, _reward, episode_over, _info = env.step(greedy_action)
        env.render()

In [4]:
def test_dqn(env, agent, episodes=10):
    print(f"Maximum reward: {env.spec.reward_threshold}")
    for episode in range(episodes):
        # if episode == 0:
        #     render_env(env, agent)
        state = env.reset()
        episode_reward = 0
        done = False
        while not done:
            action = agent.act(state, 0)  # Using 0 epsilon for greedy action selection
            next_state, reward, done, _ = env.step(action)
            episode_reward += reward
            state = next_state
        print(f"Episode: {episode+1}, Total reward: {episode_reward}")

In [5]:
# Train the reference agent for 500 episodes using train_dqn's default schedule.
env_name = "CartPole-v1"
agent = train_dqn(env_name=env_name, episodes=500)

In [6]:
# Evaluate the trained agent greedily on a fresh environment, and make sure the
# environment is closed afterwards instead of leaking the inline instance.
eval_env = gym.make(env_name)
try:
    test_dqn(eval_env, agent)
finally:
    eval_env.close()
# Last expression of the cell: display the learned parameters.
agent.model.get_weights()

Maximum reward: 475.0
Episode: 1, Total reward: 500.0
Episode: 2, Total reward: 500.0
Episode: 3, Total reward: 500.0
Episode: 4, Total reward: 500.0
Episode: 5, Total reward: 500.0
Episode: 6, Total reward: 500.0
Episode: 7, Total reward: 500.0
Episode: 8, Total reward: 500.0
Episode: 9, Total reward: 500.0
Episode: 10, Total reward: 500.0


OrderedDict([('fc.0.weight',
              tensor([[ 1.1067,  1.5748, -0.8898,  0.5813],
                      [ 0.8967, -0.5414, -6.2293, -1.2310],
                      [ 0.4517, -0.2705, -0.2390,  0.0419],
                      [ 0.0118, -0.2964, -0.5022, -0.3177],
                      [-0.0810, -0.0312,  0.1967, -0.5733],
                      [ 0.3461,  0.4033,  0.6180, -0.0125],
                      [ 0.5193,  0.1527,  0.3628,  0.2688],
                      [ 0.0599,  0.1889, -0.6775,  0.0249],
                      [ 0.3614,  0.2492, -0.4951, -0.0232],
                      [-0.1383,  0.6677,  6.0743,  1.1747],
                      [ 0.8480,  0.9523,  4.7769,  1.4152],
                      [ 0.3134,  0.4275, -0.1514, -0.1056],
                      [ 0.6302,  0.2341,  0.0317,  0.0625],
                      [-0.1284, -0.0282,  0.0093,  0.2146],
                      [ 1.0972, -0.3998, -0.7634, -0.7361],
                      [-0.1000, -0.2719, -1.1902, -0.5305],
           

In [7]:
### Coherence classifier

#agent.model.get_weights()

# Define a simple GCN model
from torch_geometric.data import Data
class GCN(torch.nn.Module):
    """Two stacked ``GCNConv`` layers (input features -> 4 -> 2).

    Args:
        data: Graph object providing ``num_node_features``; a reference to it is
            also kept on the module as ``self.data``.
    """

    def __init__(self, data):
        super().__init__()
        in_channels = data.num_node_features
        self.conv1 = GCNConv(in_channels, 4)  # input features -> hidden
        self.conv2 = GCNConv(4, 2)  # hidden -> output features
        self.data = data

    def forward(self, data):
        """Apply conv1, ReLU, then conv2 to ``data.x`` over ``data.edge_index``.

        Returns the output of the second convolution; no pooling is applied.
        """
        hidden = self.conv1(data.x, data.edge_index).relu()
        return self.conv2(hidden, data.edge_index)
    
    
def nn_to_data(model: nn.Module) -> Data:
    """Encode a feed-forward network's ``nn.Linear`` layers as a graph.

    Graph layout:
      * One node per neuron: first the inputs of the first Linear layer, then
        the outputs of each Linear layer in order.
      * Node feature (one scalar): 0 for input neurons, the bias for every
        Linear output neuron (0 if the layer has no bias).
      * One directed edge from every input neuron of a Linear layer to every
        output neuron of that layer. Only the weight matrix *shape* is used;
        weight values are not encoded.

    Layer discovery: if the model's first child is an ``nn.Sequential``, its
    direct children are scanned; otherwise the model's own direct children are
    scanned (a bare ``nn.Linear`` is treated as a one-layer network). Non-Linear
    modules are skipped.

    Node indexing assumes the Linear layers form a chain, i.e. each layer's
    ``in_features`` equals the previous layer's ``out_features``; this is not
    validated here.

    Args:
        model: The network to convert.

    Returns:
        ``Data`` with ``x`` of shape ``[num_nodes, 1]`` (float) and
        ``edge_index`` of shape ``[2, num_edges]`` (long).
    """
    # next(..., None) avoids StopIteration on a module without children.
    first_child = next(model.children(), None)
    if isinstance(model, nn.Linear):
        candidates = [model]
    elif isinstance(first_child, nn.Sequential):
        candidates = list(first_child.children())
    else:
        candidates = list(model.children())
    linear_layers = [layer for layer in candidates if isinstance(layer, nn.Linear)]

    if not linear_layers:
        # Nothing to encode: return a well-formed empty graph.
        return Data(x=torch.zeros((0, 1), dtype=torch.float),
                    edge_index=torch.zeros((2, 0), dtype=torch.long))

    # Input neurons always get a feature row, regardless of how the layers were
    # found; otherwise edges would point past the end of the node-feature matrix.
    feature_chunks = [torch.zeros(linear_layers[0].in_features, dtype=torch.float)]
    edge_chunks = []
    offset = 0  # global index of the current layer's first input neuron

    for layer in linear_layers:
        n_out, n_in = layer.weight.shape

        # Same edge order as a nested loop over inputs (outer) and outputs (inner).
        sources = torch.arange(n_in).repeat_interleave(n_out) + offset
        targets = torch.arange(n_out).repeat(n_in) + offset + n_in
        edge_chunks.append(torch.stack((sources, targets)))

        if layer.bias is None:
            feature_chunks.append(torch.zeros(n_out, dtype=torch.float))
        else:
            # detach + cpu: the graph must not track gradients or depend on
            # the device the model lives on.
            feature_chunks.append(layer.bias.detach().cpu().to(torch.float))

        # This layer's outputs are the next layer's inputs.
        offset += n_in

    edge_index = torch.cat(edge_chunks, dim=1).contiguous()
    x = torch.cat(feature_chunks).view(-1, 1)
    return Data(x=x, edge_index=edge_index)

# Turn the trained agent's Q-network into a graph and build a GCN sized for it.
data = nn_to_data(agent.model)
gcn = GCN(data)

# Debug: every edge endpoint must index an existing row of data.x.
out_of_bounds = data.edge_index >= data.x.shape[0]
offending_columns = out_of_bounds.any(dim=0)
if offending_columns.any():
    print("Out-of-bounds indices found at locations:")
    print(data.edge_index[:, offending_columns])

In [8]:
# Dataset generation
# Environment instance for the dataset-generation cells; a later cell reads its
# observation/action space sizes when constructing agents.
env = gym.make(env_name)
def deterministic_random(*args, lb = -1, ub = 1):
    """Return a pseudo-random float between ``lb`` and ``ub`` determined by ``args``.

    The string form of the ``args`` tuple is used as the seed, so identical
    arguments always yield the same value.

    A private ``random.Random`` instance is used so that the module-level
    ``random`` generator is left untouched. Reseeding the global generator
    here would make every later ``random.random()`` / ``random.sample()`` call
    in the notebook a deterministic function of the last arguments.

    Args:
        *args: Arbitrary values identifying the draw.
        lb: Lower bound of the uniform range.
        ub: Upper bound of the uniform range.
    """
    unique_seed = f"{args}".encode("utf-8")
    return random.Random(unique_seed).uniform(lb, ub)

NUM_TRAIN_R_FUNCS = 100
NUM_EPS_TRAIN_R = 100


def _make_urs_reward(func_id):
    """Build a reward function whose pseudo-random values are keyed by ``func_id``.

    Including ``func_id`` in the seed makes each generated reward function
    distinct. Seeding on the transition alone would make all of them return
    identical rewards.
    """
    def reward(s, a, ns, d):
        return deterministic_random(func_id, s, a, ns, d)
    return reward


# One distinct reward function per id, and one agent trained on each of them.
URS_r_funcs = [_make_urs_reward(func_id) for func_id in range(NUM_TRAIN_R_FUNCS)]
URS_agents = [train_dqn(env_name=env_name, episodes=NUM_EPS_TRAIN_R, reward_function=r_func)
              for r_func in URS_r_funcs]

In [14]:
# NUM_TRAIN_R_FUNCS agents that are only constructed here (train_dqn is not
# called on them); they form the class-0 ("UPS") side of the dataset below.
UPS_agents = [DQNAgent(env.observation_space.shape[0], env.action_space.n) for _ in range(NUM_TRAIN_R_FUNCS)]

import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool, GCNConv

class GraphLevelGCN(torch.nn.Module):
    """Graph-level binary classifier: two GCN layers, mean pooling, linear head.

    Args:
        num_node_features: Number of features per input node.
    """

    def __init__(self, num_node_features):
        super().__init__()
        hidden_channels = 16
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.linear = torch.nn.Linear(hidden_channels, 1)

    def forward(self, data):
        """Return the sigmoid of the linear head applied to the pooled graph.

        Reads ``data.x``, ``data.edge_index`` and ``data.batch``.
        """
        node_embeddings = F.relu(self.conv1(data.x, data.edge_index))
        node_embeddings = self.conv2(node_embeddings, data.edge_index)

        # Collapse node embeddings into one vector per graph.
        graph_embedding = global_mean_pool(node_embeddings, data.batch)

        # Squash the single output into (0, 1) for binary classification.
        return torch.sigmoid(self.linear(graph_embedding))

# Training loop
# Convert every agent's Q-network into a graph.
URS_data = [nn_to_data(agent.model) for agent in URS_agents]
print(URS_data[0].x.shape)
UPS_data = [nn_to_data(agent.model) for agent in UPS_agents]
assert URS_data[0].x.shape == UPS_data[0].x.shape

# Shuffle both groups together; positions below len(URS_data) refer to URS graphs.
indices = np.random.permutation(len(URS_data) + len(UPS_data))

data = [URS_data[i] if i < len(URS_data) else UPS_data[i - len(URS_data)] for i in indices]
for graph, source_idx in zip(data, indices):
    # One binary label per graph: 1 = URS, 0 = UPS.
    graph.y = 1.0 if source_idx < len(URS_data) else 0.0

train_data_ratio = 0.8
split_idx = int(train_data_ratio * len(data))
train_data, test_data = data[:split_idx], data[split_idx:]

# Loss and optimizer
num_node_features = 1  # just the bias term
model = GraphLevelGCN(num_node_features)
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

epochs = 20
# Number of consecutive non-improving epochs tolerated before stopping early.
patience = 3
# Early-stopping state.
# NOTE(review): the held-out split doubles as the early-stopping signal, so the
# reported test loss is not an untouched estimate.
best_loss = float('inf')
epochs_without_improvement = 0

for epoch in range(epochs):
    # Training pass: one optimizer step per graph.
    model.train()
    avg_train_loss = 0
    for datapt in train_data:
        optimizer.zero_grad()
        # Call the module itself rather than .forward() so Module.__call__ runs.
        out = model(datapt)
        # Target is shaped [1, 1] to match the model output for a single graph.
        loss = criterion(out, torch.tensor([[datapt.y]]))
        loss.backward()
        optimizer.step()
        avg_train_loss += loss.item()
    avg_train_loss /= len(train_data)

    # Evaluation pass on the held-out graphs, without gradient tracking.
    model.eval()
    avg_test_loss = 0
    with torch.no_grad():
        for datapt in test_data:
            out = model(datapt)
            loss = criterion(out, torch.tensor([[datapt.y]]))
            avg_test_loss += loss.item()
    avg_test_loss /= len(test_data)

    print(f'Epoch {epoch+1}: Average Train Loss: {avg_train_loss}, Average Test Loss: {avg_test_loss}')

    # Early Stopping
    if avg_test_loss < best_loss:
        best_loss = avg_test_loss
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print(f'Early stopping at epoch {epoch+1}')
            break

torch.Size([134, 1])
Epoch 1: Average Train Loss: 0.40541577302701626, Average Test Loss: 0.0537006069553172
Epoch 2: Average Train Loss: 0.07264251570584701, Average Test Loss: 0.021336334023372316
Epoch 3: Average Train Loss: 0.07250669127260512, Average Test Loss: 0.011036095714824867
Epoch 4: Average Train Loss: 0.0571737502666406, Average Test Loss: 0.011023389485248104
Epoch 5: Average Train Loss: 0.054839500761868276, Average Test Loss: 0.008242949860520455
Epoch 6: Average Train Loss: 0.05347820636596641, Average Test Loss: 0.0062461225085147735
Epoch 7: Average Train Loss: 0.050224280575129045, Average Test Loss: 0.004865015919730098
Epoch 8: Average Train Loss: 0.04852196569670415, Average Test Loss: 0.003984942406126279
Epoch 9: Average Train Loss: 0.04454295151290797, Average Test Loss: 0.003437603973772241
Epoch 10: Average Train Loss: 0.0442975319416596, Average Test Loss: 0.0031744233318995894
Epoch 11: Average Train Loss: 0.04262183145602867, Average Test Loss: 0.002960

In [15]:
# Test GCN model on a "more powerful" NN
# Score one URS graph, one UPS graph, and finally the graph of `agent`, the
# agent trained for 500 episodes in the earlier training cell.
urs_prediction = model.forward(URS_data[0])
print(urs_prediction)
ups_prediction = model.forward(UPS_data[0])
print(ups_prediction)
trained_prediction = model.forward(nn_to_data(agent.model))
print(trained_prediction)

tensor([[0.9998]], grad_fn=<SigmoidBackward0>)
tensor([[0.0042]], grad_fn=<SigmoidBackward0>)
tensor([[1.]], grad_fn=<SigmoidBackward0>)
