In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random

from libs.srenv import SREnv
from agents.rlagent import DQNAgent, ReplayBuffer

In [2]:
# Seed the RNGs this notebook imports so the synthetic data below (and anything else drawing
# from torch / random) is reproducible across Restart & Run All.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Symbol library passed to SREnv. The value appears to be the arity
# (binary ops 2, unary ops 1, leaves 0) — TODO confirm against SREnv.
library = {
    '+': 2,
    '-': 2,
    '*': 2,
    '/': 2,
    # Optional unary operators; uncomment to enable.
    # 'sin': 1,
    # 'cos': 1,
    'C': 0  # Placeholder for constants
}

# Synthetic regression data
n_samples = 1000
n_vars = 1

# Register one input-variable token per feature (X0, X1, ...), appended after 'C'.
library.update({f'X{var_idx}': 0 for var_idx in range(n_vars)})

data = torch.randn(n_vars, n_samples)  # Shape: (n_vars, n_samples)
target = 2 * data[0] + 1  # Ground truth the agent should recover: 2 * X0 + 1

# Initialize the environment
max_depth = 10
env = SREnv(library=library, data=data, target=target, max_depth=max_depth)

In [3]:
# Sanity check: the target holds exactly one value per sample.
assert target.shape == (n_samples,), f"unexpected target shape {tuple(target.shape)}"
print(target.shape)

torch.Size([1000])


In [4]:
# Token vocabulary: every library symbol, in library order, followed by a padding token.
vocab = [*library.keys(), 'PAD']
symbol_to_index = {}
for position, token in enumerate(vocab):
    symbol_to_index[token] = position
index_to_symbol = {position: token for token, position in symbol_to_index.items()}
vocab_size = len(vocab)

# Number of tokens every encoded state is padded/truncated to.
# NOTE(review): this reuses max_depth as a token-length limit; confirm SREnv's max_depth
# bounds the prefix-sequence length rather than the tree depth.
max_seq_length = max_depth

def encode_state(state):
    """Encode a sequence of symbols as a fixed-length tensor of vocabulary indices.

    Each symbol is looked up in ``symbol_to_index`` (KeyError if unknown). The result is
    truncated to the first ``max_seq_length`` indices, or right-padded with the 'PAD' index
    when shorter. Returns a 1-D ``torch.long`` tensor.
    """
    encoded = [symbol_to_index[token] for token in state][:max_seq_length]
    shortfall = max_seq_length - len(encoded)
    if shortfall > 0:
        encoded += [symbol_to_index['PAD']] * shortfall
    return torch.tensor(encoded, dtype=torch.long)


In [5]:
# The agent's actions are the library symbols (no 'PAD'), in library order; because vocab is
# library keys + 'PAD', each action index equals that symbol's vocabulary index.
action_symbols = [*library]
action_size = len(action_symbols)
# NOTE(review): symbol_to_action_idx is not used by the visible cells.
symbol_to_action_idx = dict(zip(action_symbols, range(action_size)))

In [6]:
# Show the action space: the library's operators, the constant placeholder 'C', and the variables.
print(action_symbols)

['+', '-', '*', '/', 'C', 'X0']


In [7]:
# Model size
embedding_dim = 128
hidden_dim = 256

# Training schedule
num_episodes = 600
batch_size = 250
gamma = 0.99  # discount factor for bootstrapped targets

# Epsilon-greedy exploration; the training cell multiplies epsilon by epsilon_decay once per episode.
epsilon_start = 1.0
epsilon_end = 0.1
epsilon_decay = 0.95

# Copy online weights into the target network every `target_update` episodes.
# Integer division keeps `episode % target_update == 0` on an integer cadence (a float
# period such as 60.5 would silently skip updates); max(1, ...) avoids modulo-by-zero
# when num_episodes < 10.
target_update = max(1, num_episodes // 10)
memory_capacity = max_seq_length * num_episodes

# NOTE(review): `device` is not used by the visible cells; models and batches stay on the default device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
agent = DQNAgent(vocab_size, embedding_dim, hidden_dim, action_size)
target_agent = DQNAgent(vocab_size, embedding_dim, hidden_dim, action_size)
target_agent.load_state_dict(agent.state_dict())
# The target network only produces bootstrap targets under no_grad, so keep it in eval mode
# (this disables dropout / batch-norm statistic updates if DQNAgent has any).
target_agent.eval()
agent.train()

optimizer = optim.Adam(agent.parameters(), lr=1e-4)
criterion = nn.MSELoss()
memory = ReplayBuffer(memory_capacity)


def _replay_step():
    """Sample `batch_size` transitions from `memory` and take one optimizer step on `agent`."""
    states_batch, actions_batch, rewards_batch, next_states_batch, dones_batch = memory.sample(batch_size)

    # Q(s, a) for the actions that were actually taken.
    q_values = agent(states_batch).gather(1, actions_batch.unsqueeze(1)).squeeze(1)

    # Target r + gamma * max_a' Q_target(s', a'), with the bootstrap term zeroed on terminal steps.
    with torch.no_grad():
        next_q_values = target_agent(next_states_batch).max(dim=1).values
        # .float() so the mask works whether the buffer returns bool or numeric done flags.
        not_done = 1.0 - dones_batch.float()
        target_q_values = rewards_batch + gamma * next_q_values * not_done

    loss = criterion(q_values, target_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


epsilon = epsilon_start

for episode in range(num_episodes):
    state_symbols = env.reset()
    state_encoded = encode_state(state_symbols)  # Shape: (max_seq_length,)
    done = False
    total_reward = 0

    # At most max_seq_length actions per episode.
    for _ in range(max_seq_length):
        action_idx = agent.act(state_encoded, epsilon)
        action_symbol = action_symbols[action_idx]

        try:
            next_state_symbols, reward, done = env.step(action_symbol)
        except ValueError:
            # env.step rejected the action: end the episode with no reward and remain in the same state.
            # NOTE(review): the evaluation cell scores this case as -1.0; consider penalising it here too.
            reward = 0
            done = True
            next_state_symbols = state_symbols

        next_state_encoded = encode_state(next_state_symbols)
        total_reward += reward

        # Store transition in memory
        memory.push(state_encoded, action_idx, reward, next_state_encoded, done)

        state_encoded = next_state_encoded
        state_symbols = next_state_symbols

        # Experience replay once the buffer holds a full batch.
        if len(memory) >= batch_size:
            _replay_step()

        if done:
            break

    # ReplayBuffer-specific hook, called once per episode with the final done flag.
    memory.prioritise(done)

    # Decay epsilon once per episode, clamped so it never falls below epsilon_end.
    epsilon = max(epsilon_end, epsilon * epsilon_decay)

    # Periodically sync the target network with the online network.
    if episode % target_update == 0:
        target_agent.load_state_dict(agent.state_dict())

    if done:
        print(f"Episode {episode} completed, Total Reward: {total_reward}")
    else:
        print(f"Episode {episode} failed, Total Reward: {total_reward}")

Episode 0 failed, Total Reward: 0
Episode 1 failed, Total Reward: 0
Episode 2 failed, Total Reward: 0
Episode 3 completed, Total Reward: 0.32000935077667236
Episode 4 failed, Total Reward: 0
Episode 5 failed, Total Reward: 0
Episode 6 failed, Total Reward: 0
Episode 7 failed, Total Reward: 0
Episode 8 failed, Total Reward: 0
Episode 9 failed, Total Reward: 0
Episode 10 failed, Total Reward: 0
Episode 11 failed, Total Reward: 0
Episode 12 completed, Total Reward: 0.18622447550296783
Episode 13 completed, Total Reward: 0.32000935077667236
Episode 14 failed, Total Reward: 0
Episode 15 completed, Total Reward: 0.18622447550296783
Episode 16 failed, Total Reward: 0
Episode 17 failed, Total Reward: 0
Episode 18 failed, Total Reward: 0
Episode 19 failed, Total Reward: 0
Episode 20 failed, Total Reward: 0
Episode 21 failed, Total Reward: 0
Episode 22 failed, Total Reward: 0
Episode 23 completed, Total Reward: 0.18622447550296783
Episode 24 failed, Total Reward: 0
Episode 25 failed, Total Rewar

In [10]:
# Testing the agent: greedy rollouts (argmax Q, no exploration).
agent.eval()
target_agent.eval()

# A greedy policy can keep producing the same unfinished expression after every reset,
# so bound the number of retries instead of looping forever.
max_attempts = 10

for attempt in range(max_attempts):
    state_symbols = env.reset()
    state_encoded = encode_state(state_symbols)
    done = False
    total_reward = 0
    expression_actions = []

    # Same per-episode step budget as training.
    for _ in range(max_seq_length):
        with torch.no_grad():
            q_values = agent(state_encoded.unsqueeze(0))  # Add batch dimension
            action_idx = torch.argmax(q_values).item()
        action_symbol = action_symbols[action_idx]
        expression_actions.append(action_symbol)

        try:
            next_state_symbols, reward, done = env.step(action_symbol)
        except ValueError:
            # Invalid action: penalise it and stop this rollout in the current state.
            reward = -1.0
            done = True
            next_state_symbols = state_symbols

        total_reward += reward
        state_symbols = next_state_symbols
        state_encoded = encode_state(next_state_symbols)

        if done:
            break

    # Only reset and retry when the rollout did not finish; a finished result is kept.
    if done:
        break
    if attempt + 1 < max_attempts:
        print('restarting...')

if done:
    # Replace the first n_constants 'C' placeholders, in order, with the optimised constant values.
    n_const = env.expression.n_constants
    const_count = 0
    for idx, token in enumerate(expression_actions):
        if const_count == n_const:
            break
        if token == 'C':
            expression_actions[idx] = str(env.expression.optimized_constants[const_count].item())
            const_count += 1
else:
    print(f"No complete expression after {max_attempts} attempts of {max_seq_length} steps.")

print(f"Constructed Expression: {' '.join(expression_actions)}")
print(f"Test Total Reward: {total_reward}")

Constructed Expression: + 1.0 + X0 X0
Test Total Reward: 1.0
