In [1]:
import gym
import numpy as np
import random
from collections import deque
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch_geometric.nn.conv import GCNConv

Setup:
- Starting point: just try to train classifier on RL policies

In [2]:
### DQN implementation

# Define the Q-network
class DQN(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super(DQN, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, x):
        return self.fc(x)
    
    def get_weights(self):
        return self.state_dict()

# Experience Replay Buffer
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        state = np.expand_dims(state, 0)
        next_state = np.expand_dims(next_state, 0)
            
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        return np.concatenate(state), action, reward, np.concatenate(next_state), done

    def __len__(self):
        return len(self.buffer)

# DQN Agent
class DQNAgent:
    def __init__(self, state_dim, action_dim, hidden_dim=64, lr=1e-3, batch_size=64, gamma=0.99, replay_size=1000):
        self.model = DQN(state_dim, action_dim, hidden_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.replay_buffer = ReplayBuffer(replay_size)
        self.batch_size = batch_size
        self.gamma = gamma
        self.action_dim = action_dim
    
    def update(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        state, action, reward, next_state, done = self.replay_buffer.sample(self.batch_size)
        
        state = torch.FloatTensor(state)
        next_state = torch.FloatTensor(next_state)
        action = torch.LongTensor(action)
        reward = torch.FloatTensor(reward)
        done = torch.FloatTensor(done)

        q_values = self.model(state)
        next_q_values = self.model(next_state)
        
        q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
        next_q_value = next_q_values.max(1)[0]
        expected_q_value = reward + self.gamma * next_q_value * (1 - done)
        
        loss = nn.MSELoss()(q_value, expected_q_value.detach())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def act(self, state, epsilon):
        if random.random() > epsilon:
            state = torch.FloatTensor(np.expand_dims(state, 0))
            q_value = self.model(state)
            action = q_value.max(1)[1].item()
        else:
            action = random.randrange(self.action_dim)
        return action

In [3]:
def train_dqn(env_name="CartPole-v1", episodes=500, epsilon_start=1.0, epsilon_final=0.01, 
              epsilon_decay=500):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    agent = DQNAgent(state_dim, action_dim)
    
    epsilon_by_frame = lambda frame_idx: epsilon_final + (epsilon_start - epsilon_final) * np.exp(-1. * frame_idx / epsilon_decay)
    
    for episode in range(episodes):
        state = env.reset()
        episode_reward = 0
        while True:
            epsilon = epsilon_by_frame(episode)
            action = agent.act(state, epsilon)
            next_state, reward, done, _ = env.step(action)
            agent.replay_buffer.push(state, action, reward, next_state, done)
            state = next_state
            episode_reward += reward
            
            agent.update()
            
            if done:
                break
            # print(f"Episode: {episode+1}, Total reward: {episode_reward}, Epsilon: {epsilon:.2f}")

        # Optional: Render the environment to visualize training progress
        # if episode % 100 == 0:
        #     render_env(env, agent)

    env.close()
    return agent

# Optional: Function to render the environment with the current policy
def render_env(env, agent):
    state = env.reset()
    done = False
    while not done:
        action = agent.act(state, 0)  # Using 0 epsilon for greedy action selection
        # print(env.step(action))
        next_state, reward, done, _ = env.step(action)
        env.render()
        state = next_state

In [4]:
def test_dqn(env, agent, episodes=10):
    print(f"Maximum reward: {env.spec.reward_threshold}")
    for episode in range(episodes):
        # if episode == 0:
        #     render_env(env, agent)
        state = env.reset()
        episode_reward = 0
        done = False
        while not done:
            action = agent.act(state, 0)  # Using 0 epsilon for greedy action selection
            next_state, reward, done, _ = env.step(action)
            episode_reward += reward
            state = next_state
        print(f"Episode: {episode+1}, Total reward: {episode_reward}")

In [5]:
agent = train_dqn(episodes = 50)

In [6]:
test_dqn(gym.make("CartPole-v1"), agent)
agent.model.get_weights()

Maximum reward: 475.0
Episode: 1, Total reward: 65.0
Episode: 2, Total reward: 249.0
Episode: 3, Total reward: 200.0
Episode: 4, Total reward: 340.0
Episode: 5, Total reward: 500.0
Episode: 6, Total reward: 206.0
Episode: 7, Total reward: 500.0
Episode: 8, Total reward: 88.0
Episode: 9, Total reward: 132.0
Episode: 10, Total reward: 303.0


OrderedDict([('fc.0.weight',
              tensor([[ 8.0815e-01,  2.2713e-01, -1.5328e+00, -5.7343e-01],
                      [ 5.3673e-01, -1.4522e-01, -4.6750e-01, -6.8723e-01],
                      [-9.3924e-02, -1.8958e-02,  3.8947e-01, -3.0018e-01],
                      [-1.5461e-01, -3.5016e-01, -9.7916e-02, -1.6390e-01],
                      [-5.3706e-01, -9.6581e-01,  4.8666e-01,  8.0225e-01],
                      [ 1.8764e-01,  8.2671e-02, -1.0861e+00, -6.7551e-01],
                      [-3.3833e-02, -3.0430e-01,  4.4604e-01, -2.2593e-01],
                      [-1.2941e-03, -2.2370e-01,  2.2220e-01,  1.1660e-01],
                      [-8.6939e-04,  2.2074e-01,  4.9737e-01,  1.0742e-02],
                      [ 7.9584e-01,  3.4847e-01, -1.1259e+00, -4.8410e-01],
                      [ 3.1187e-01,  1.9612e-01, -1.3736e+00, -2.8820e-01],
                      [ 5.9451e-01, -3.2977e-01, -2.1979e-01, -3.7078e-02],
                      [ 9.7603e-01,  3.8013e-01, -7.9632e-0

In [10]:
### Coherence classifier

#agent.model.get_weights()

# Define a simple GCN model
from torch_geometric.data import Data
class GCN(torch.nn.Module):
    def __init__(self, data):
        super(GCN, self).__init__()
        # Define the GCN layers
        self.conv1 = GCNConv(data.num_node_features, 4)  # Input features to hidden
        self.conv2 = GCNConv(4, 2)  # Hidden to output features
        self.data = data

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        
        # Pass data through the first GCN layer, then apply ReLU
        x = torch.relu(self.conv1(x, edge_index))
        # Pass data through the second GCN layer
        x = self.conv2(x, edge_index)
        return x
    
    
def nn_to_data(model: nn.Module) -> Data:
    edges = []
    node_features = []

    # Counter for global neuron index
    global_neuron_index = 0

    # Iterate over each layer in the network
    base = next(model.children())
    if isinstance(base, nn.Sequential):
        layers = base.children()
    else:
        layers = model.children()

    if isinstance(base, nn.Sequential):
        input_dim = base[0].in_features
        node_features = np.zeros(input_dim)

    for layer in layers:
        if isinstance(layer, nn.Linear):
            # Update edges based on the weight matrix
            for i in range(layer.weight.shape[1]):  # Input neurons
                for j in range(layer.weight.shape[0]):  # Output neurons
                    edges.append((global_neuron_index + i, global_neuron_index + layer.weight.shape[1] + j))
            
            # Update node features (e.g., biases)
            extension = layer.bias.detach().numpy()
            node_features = np.append(node_features, extension)
            
            # Update the global neuron index
            global_neuron_index += layer.weight.shape[1]

    # Convert lists to PyTorch tensors
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    x = torch.tensor(node_features, dtype=torch.float).view(-1, 1)
    return Data(x=x, edge_index=edge_index)

data = nn_to_data(agent.model)
gcn = GCN(data)
# data.x.shape, data.edge_index.shape
# print(data.x)

#Debug
out_of_bounds = data.edge_index >= data.x.shape[0]
if out_of_bounds.any():
    print("Out-of-bounds indices found at locations:")
    print(data.edge_index[:, out_of_bounds.any(dim=0)])

In [15]:
# Dataset generation
random_r_funcs = 

# Loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(gcn.parameters(), lr=0.01)

# Training loop
epochs = 200
for epoch in range(epochs):
    gcn.train()
    optimizer.zero_grad()   # Clear gradients
    out = gcn.forward(data)       # Forward pass
    print(out.shape)

    answer = 0 #TODO: fill this out
    loss = criterion(out, answer)  # Compute loss
    loss.backward()         # Backward pass
    optimizer.step()        # Update weights

    # Optional: Print loss every X epochs
    if epoch % 10 == 0:
        print(f'Epoch {epoch+1}/{epochs}, Loss: {loss.item()}')

torch.Size([134, 2])


TypeError: cross_entropy_loss(): argument 'target' (position 2) must be Tensor, not int