In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import numpy as np

from libs.srenv import SREnv
from agents.rlagent import DQNAgent, ReplayBuffer

In [14]:
# Sanity check of torch.stack: two length-5 vectors become the rows of a 2 x 5 tensor.
a, b = torch.ones(5), torch.full((5,), 2.0)
print(torch.stack((a, b)))

tensor([[1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.]])


In [16]:
# Seed every RNG the notebook imports so the synthetic data (and any sampling done
# in later cells) is reproducible under Restart & Run All.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Symbol library for the expression grammar. The values look like arities
# (binary operators 2, unary functions 1, leaves 0) — TODO confirm against SREnv.
library = {
    '+': 2,
    '-': 2,
    '*': 2,
    '/': 2,
    # 'sin': 1,
    # 'cos': 1,
    'C': 0  # Placeholder for constants
}

# Synthetic regression problem
n_samples = 1000
n_vars = 2

# Input variables X0 .. X{n_vars-1} are added as leaf symbols.
library.update({f'X{i}': 0 for i in range(n_vars)})

# Variable i is drawn from N(i + 1, 1): unit-variance noise plus a per-row offset
# that broadcasts across the sample axis.
noise = torch.randn(n_vars, n_samples)
offsets = torch.arange(1, n_vars + 1, dtype=noise.dtype).unsqueeze(1)  # (n_vars, 1)
data = noise + offsets  # Shape: (n_vars, n_samples)
target = 2 * data[0] + data[1]  # ground-truth relation: target = 2*X0 + X1

# Initialize the environment
max_depth = 10
env = SREnv(library=library, data=data, target=target, max_depth=max_depth)

In [18]:
# Confirm the target is 1-D over samples and the data is (n_vars, n_samples).
print(target.shape, data.shape, sep='\n')

torch.Size([1000])
torch.Size([2, 1000])


In [4]:
# Token vocabulary: every library symbol followed by a dedicated padding token.
vocab = [*library.keys(), 'PAD']
symbol_to_index = {sym: pos for pos, sym in enumerate(vocab)}
index_to_symbol = {pos: sym for sym, pos in symbol_to_index.items()}
vocab_size = len(vocab)

# Length every encoded state is padded or truncated to.
max_seq_length = max_depth

def encode_state(state, symbol_map=None, max_len=None):
    """Encode a sequence of expression symbols as a fixed-length index tensor.

    Args:
        state: iterable of symbols; each must be a key of ``symbol_map``.
        symbol_map: mapping symbol -> integer index; must contain 'PAD' whenever
            padding is needed. Defaults to the notebook-level ``symbol_to_index``.
        max_len: length of the returned tensor. Defaults to the notebook-level
            ``max_seq_length``.

    Returns:
        1-D ``torch.long`` tensor of length ``max_len``: the symbol indices,
        right-padded with the 'PAD' index, or truncated to ``max_len``.

    Raises:
        KeyError: if a symbol in ``state`` (or 'PAD', when padding) is missing
            from ``symbol_map``.
    """
    if symbol_map is None:
        symbol_map = symbol_to_index
    if max_len is None:
        max_len = max_seq_length

    indices = [symbol_map[symbol] for symbol in state][:max_len]
    n_pad = max_len - len(indices)
    if n_pad > 0:
        # Look up PAD only when it is actually needed.
        indices.extend([symbol_map['PAD']] * n_pad)
    return torch.tensor(indices, dtype=torch.long)


In [5]:
# The action space is the library itself (PAD is never an action). It follows the
# same key order as `vocab`, so action index i is also vocabulary index i.
action_symbols = [*library]
action_size = len(action_symbols)
symbol_to_action_idx = dict(zip(action_symbols, range(action_size)))

In [6]:
# Show the available actions (operators, constant placeholder, input variables).
print(f"{action_symbols}")

['+', '-', '*', '/', 'C', 'X0']


In [7]:
# Network sizes
embedding_dim = 128
hidden_dim = 256

# Training schedule
num_batches = 100
num_episodes_per_batch = 10
batch_quantile = 0.15   # fraction of best episodes per batch kept for replay
batch_size = 250        # replay sample size per gradient step
gamma = 0.99            # discount factor for bootstrapped Q-targets

# Epsilon-greedy exploration: multiplied by `epsilon_decay` until it reaches `epsilon_end`.
epsilon_start = 1.0
epsilon_end = 0.1
epsilon_decay = 0.95

# Target-network sync interval. Integer division keeps `%` checks exact, and the
# floor of 1 avoids a modulo-by-zero when num_batches < 10.
target_update = max(1, num_batches // 10)

# Upper bound on the transitions ever generated: steps per episode x episodes x batches.
memory_capacity = max_seq_length * num_episodes_per_batch * num_batches
batch_eval = 5          # run a greedy evaluation every `batch_eval` batches

# NOTE(review): `device` is not used by the training cell; the agent and tensors
# stay on PyTorch's default device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [8]:
# Online and target Q-networks start from identical weights; only `agent` is trained.
agent = DQNAgent(vocab_size, embedding_dim, hidden_dim, action_size)
target_agent = DQNAgent(vocab_size, embedding_dim, hidden_dim, action_size)
target_agent.load_state_dict(agent.state_dict())
target_agent.eval()
agent.train()

optimizer = optim.Adam(agent.parameters(), lr=1e-4)
criterion = nn.MSELoss()
memory = ReplayBuffer(memory_capacity)


def _run_episode(epsilon):
    """Roll out one epsilon-greedy episode from a fresh ``env.reset()``.

    Returns ``(transitions, total_reward, done)``, where ``transitions`` is a list of
    ``(state, action_idx, reward, next_state, done)`` tuples ready for ``memory.push``.
    An episode that uses all ``max_seq_length`` steps without finishing counts as a
    failure: every transition's reward and the episode total are overwritten with -1.
    """
    state_symbols = env.reset()
    state_encoded = encode_state(state_symbols)
    done = False
    total_reward = 0
    transitions = []

    for _ in range(max_seq_length):
        action_idx = agent.act(state_encoded, epsilon)
        try:
            next_state_symbols, reward, done = env.step(action_symbols[action_idx])
        except ValueError:
            # The env rejected the action: end the episode in place with zero reward.
            # NOTE(review): _greedy_rollout scores this same case as -1.0.
            reward = 0
            done = True
            next_state_symbols = state_symbols

        next_state_encoded = encode_state(next_state_symbols)
        total_reward += reward
        transitions.append((state_encoded, action_idx, reward, next_state_encoded, done))
        state_encoded, state_symbols = next_state_encoded, next_state_symbols
        if done:
            break

    if not done:
        transitions = [(s, a, -1, ns, d) for s, a, _, ns, d in transitions]
        total_reward = -1
    return transitions, total_reward, done


def _optimize_step():
    """Apply one DQN update from a replay sample; no-op until memory holds batch_size items."""
    if len(memory) < batch_size:
        return
    states, actions, rewards, next_states, dones = memory.sample(batch_size)
    q_values = agent(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        # Bootstrapped target from the frozen network; terminal transitions get no future value.
        next_q_values = target_agent(next_states).max(dim=1)[0]
        target_q_values = rewards + gamma * next_q_values * (1 - dones)
    loss = criterion(q_values, target_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def _greedy_rollout():
    """Build one expression by always taking the highest-Q action (no exploration).

    Returns ``(actions, total_reward, done)``; an action the env rejects ends the
    rollout with reward -1.0.
    """
    state_symbols = env.reset()
    state_encoded = encode_state(state_symbols)
    done = False
    total_reward = 0
    actions = []
    for _ in range(max_seq_length):
        with torch.no_grad():
            q_values = agent(state_encoded.unsqueeze(0))  # Add batch dimension
        action_symbol = action_symbols[torch.argmax(q_values).item()]
        actions.append(action_symbol)
        try:
            next_state_symbols, reward, done = env.step(action_symbol)
        except ValueError:
            reward = -1.0
            done = True
            next_state_symbols = state_symbols
        total_reward += reward
        state_symbols = next_state_symbols
        state_encoded = encode_state(state_symbols)
        if done:
            break
    return actions, total_reward, done


epsilon = epsilon_start

for batch in range(num_batches):
    episodes = []
    for episode in range(num_episodes_per_batch):
        transitions, total_reward, done = _run_episode(epsilon)
        episodes.append((transitions, total_reward))

        # One gradient step per episode, on experience selected in earlier batches.
        _optimize_step()

        # Decay exploration, never dropping below the configured floor.
        epsilon = max(epsilon_end, epsilon * epsilon_decay)

        status = "completed" if done else "failed"
        print(f"Episode {batch}.{episode} {status}, Total Reward: {total_reward}")

    # Keep only the best `batch_quantile` of this batch's episodes. Doing this once
    # per batch means each selected episode enters the replay buffer exactly once.
    threshold = np.quantile([reward for _, reward in episodes], 1 - batch_quantile)
    for episode_transitions, episode_reward in episodes:
        if episode_reward >= threshold:
            for transition in episode_transitions:
                memory.push(*transition)

    # Periodically sync the target network with the online network.
    if (batch + 1) % target_update == 0:
        target_agent.load_state_dict(agent.state_dict())

    if batch % batch_eval == 0:
        # Testing the agent
        print('---------------------')
        print('Evaluating...')
        print('---------------------')
        agent.eval()
        expression_actions, total_reward, done = _greedy_rollout()
        # Tolerance instead of `== 1` so float round-off cannot hide a perfect fit.
        if done and np.isclose(total_reward, 1.0):
            print('Found expression! Stopping early...')
            break
        agent.train()

Episode 0.0 completed, Total Reward: 0.9585849046707153
Episode 0.1 failed, Total Reward: -1
Episode 0.2 completed, Total Reward: 0.9585849046707153
Episode 0.3 failed, Total Reward: -1
Episode 0.4 completed, Total Reward: 0.48955413699150085
Episode 0.5 failed, Total Reward: -1
Episode 0.6 completed, Total Reward: 1.0
Episode 0.7 completed, Total Reward: 0.9585849046707153
Episode 0.8 completed, Total Reward: 0.9585849046707153
Episode 0.9 completed, Total Reward: 1.0
---------------------
Evaluating...
---------------------
Episode 1.0 failed, Total Reward: -1
Episode 1.1 failed, Total Reward: -1
Episode 1.2 completed, Total Reward: 0.9585849046707153
Episode 1.3 completed, Total Reward: 0.9585849046707153
Episode 1.4 completed, Total Reward: 0.38318127393722534
Episode 1.5 completed, Total Reward: 0.9585849046707153
Episode 1.6 completed, Total Reward: 0.5912706851959229
Episode 1.7 completed, Total Reward: 0.9585849046707153
Episode 1.8 completed, Total Reward: 0.9520884156227112
E

In [9]:
# Testing the agent: greedy rollouts. If no complete expression is built within
# `max_seq_length` steps, retry, up to `max_restart` attempts.
# NOTE(review): the policy is greedy, so a retry only differs from the previous
# attempt if env.reset() or the agent's forward pass is stochastic — confirm.
agent.eval()
target_agent.eval()

max_restart = 100
done = False
total_reward = 0
expression_actions = []

for attempt in range(max_restart):
    state_symbols = env.reset()
    state_encoded = encode_state(state_symbols)
    total_reward = 0
    expression_actions = []

    for _ in range(max_seq_length):
        with torch.no_grad():
            q_values = agent(state_encoded.unsqueeze(0))  # Add batch dimension
        action_symbol = action_symbols[torch.argmax(q_values).item()]
        expression_actions.append(action_symbol)

        try:
            next_state_symbols, reward, done = env.step(action_symbol)
        except ValueError:
            reward = -1.0
            done = True
            next_state_symbols = state_symbols

        total_reward += reward
        state_symbols = next_state_symbols
        state_encoded = encode_state(state_symbols)
        if done:
            break

    # Stop on the first finished rollout; its actions and reward are kept intact.
    if done:
        break
    print('restarting...')

if done:
    # Replace the 'C' placeholders, in order, with the env's fitted constant values.
    n_const = env.expression.n_constants
    const_count = 0
    for idx, token in enumerate(expression_actions):
        if const_count == n_const:
            break
        if token == 'C':
            const_val = env.expression.optimized_constants[const_count].item()
            expression_actions[idx] = str(round(const_val, 3))
            const_count += 1
else:
    print(f"No complete expression found after {max_restart} attempts.")

print(f"Constructed Expression: {' '.join(expression_actions)}")
print(f"Test Total Reward: {total_reward}")
print(f"Test Total Reward: {total_reward}")

Constructed Expression: * * * X0 -0.317 0.327 -1.929
Test Total Reward: 1.0
