In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from libs.srenv import SREnv
from agents.rlagent import SymbolicRegressionAgent

In [7]:
lib = {
    '+': 2,
    '-': 2,
    '/': 2,
    '*': 2,
    'cos': 1,
    'sin': 1,
    'X0': 0,
    'X1': 0,
    'C': 0
}

X = torch.randn([2, 5])
y = 2.2 - X[0] / 11 + 7 * torch.cos(X[1])

env = SREnv(library=lib, data=X, target=y)

In [8]:
# Define vocabulary and action space
vocab = list(lib.keys()) + ['PAD']  # Include 'PAD' for padding
vocab_size = len(vocab)
action_size = len(lib)  # Actions correspond to symbols in the lib

# Map symbols to indices
symbol_to_index = {symbol: idx for idx, symbol in enumerate(vocab)}
index_to_symbol = {idx: symbol for symbol, idx in symbol_to_index.items()}

# Hyperparameters
embedding_dim = 128
hidden_dim = 256

agent = SymbolicRegressionAgent(vocab_size, embedding_dim, hidden_dim, action_size)
agent_optimizer = torch.optim.Adam(agent.parameters(), lr=1e-4)


In [9]:
num_episodes = 1000
max_steps_per_episode = 50  # Maximum number of steps per episode
gamma = 0.99  # Discount factor
epsilon = 0.2  # PPO clip parameter
entropy_coeff = 0.01  # Coefficient for entropy regularization
value_coeff = 0.5  # Coefficient for value loss

In [10]:
for episode in range(num_episodes):
    state = env.reset()  # Initial state: list of symbols (encoded tree)
    state_indices = [symbol_to_index[symbol] for symbol in state]
    episode_reward = 0
    done = False
    step = 0
    
    # Lists to store experiences
    states = []
    actions = []
    rewards = []
    dones = []
    log_probs = []
    values = []
    action_masks = []
    
    while not done and step < max_steps_per_episode:
        # Convert state to tensor
        state_tensor = torch.tensor([state_indices], dtype=torch.long)  # Shape: (1, seq_length)
        
        # Generate action mask
        valid_actions = env.tree.library.keys()
        action_mask = torch.zeros(action_size)
        for action in valid_actions:
            if action in env.tree.library:
                action_idx = list(env.tree.library.keys()).index(action)
                action_mask[action_idx] = 1
        action_mask = action_mask.unsqueeze(0)  # Shape: (1, action_size)
        
        # Agent predicts action probabilities and value
        action_probs, value = agent(state_tensor, action_mask)
        
        # Sample action
        dist = torch.distributions.Categorical(action_probs)
        action = dist.sample()
        action_log_prob = dist.log_prob(action)
        
        # Map action index to symbol
        action_symbol = list(env.tree.library.keys())[action.item()]
        
        # Take action in the environment
        next_state, reward, done = env.step(action_symbol)
        next_state_indices = [symbol_to_index[symbol] for symbol in next_state]
        
        # Store experiences
        states.append(state_indices)
        actions.append(action)
        rewards.append(torch.tensor([reward], dtype=torch.float))
        dones.append(torch.tensor([done], dtype=torch.float))
        log_probs.append(action_log_prob)
        values.append(value)
        action_masks.append(action_mask)
        
        # Update state
        state_indices = next_state_indices
        episode_reward += reward
        step += 1
    
    # Compute returns and advantages
    returns = []
    advantages = []
    Gt = 0
    for i in reversed(range(len(rewards))):
        Gt = rewards[i] + gamma * Gt * (1 - dones[i])
        returns.insert(0, Gt)
        advantage = Gt - values[i].detach()
        advantages.insert(0, advantage)
    
    returns = torch.cat(returns)
    advantages = torch.cat(advantages)
    log_probs = torch.cat(log_probs)
    values = torch.cat(values)
    
    # Convert lists to tensors
    actions_tensor = torch.cat(actions)
    action_masks_tensor = torch.cat(action_masks)
    
    # Update policy and value networks
    agent_optimizer.zero_grad()
    
    # Recompute action probabilities and values
    states_tensor = torch.tensor(states, dtype=torch.long)
    action_masks_tensor = action_masks_tensor
    new_action_probs, new_values = agent(states_tensor, action_masks_tensor)
    new_dist = torch.distributions.Categorical(new_action_probs)
    new_log_probs = new_dist.log_prob(actions_tensor)
    
    # Compute ratio (new / old policy)
    ratio = torch.exp(new_log_probs - log_probs.detach())
    
    # Compute surrogate loss
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
    policy_loss = -torch.min(surr1, surr2).mean()
    
    # Compute value loss
    value_loss = F.mse_loss(new_values.squeeze(-1), returns)
    
    # Compute entropy bonus
    entropy = new_dist.entropy().mean()
    
    # Total loss
    loss = policy_loss + value_coeff * value_loss - entropy_coeff * entropy
    loss.backward()
    agent_optimizer.step()
    
    # Logging
    print(f"Episode {episode+1}/{num_episodes}, Reward: {episode_reward}, Loss: {loss.item()}")


  value_loss = F.mse_loss(new_values.squeeze(-1), returns)


Episode 1/1000, Reward: 0.013888388872146606, Loss: -0.0329558402299881
Episode 2/1000, Reward: 0.013888388872146606, Loss: 0.01258053444325924


ValueError: expected sequence of length 1023 at dim 1 (got 2)