In [None]:
!pip install optuna gymnasium stable-baselines3

Collecting optuna
  Downloading optuna-4.0.0-py3-none-any.whl.metadata (16 kB)
Collecting gymnasium
  Downloading gymnasium-0.29.1-py3-none-any.whl.metadata (10 kB)
Collecting stable-baselines3
  Downloading stable_baselines3-2.3.2-py3-none-any.whl.metadata (5.1 kB)
Collecting alembic>=1.5.0 (from optuna)
  Downloading alembic-1.13.3-py3-none-any.whl.metadata (7.4 kB)
Collecting colorlog (from optuna)
  Downloading colorlog-6.8.2-py3-none-any.whl.metadata (10 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Collecting Mako (from alembic>=1.5.0->optuna)
  Downloading Mako-1.3.5-py3-none-any.whl.metadata (2.9 kB)
Downloading optuna-4.0.0-py3-none-any.whl (362 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m362.8/362.8 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading gymnasium-0.29.1-py3-none-any.whl (953 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [

In [None]:
import optuna
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3 import PPO, A2C
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
from collections import deque
from tqdm import tqdm

class DeepSeaEnv(gym.Env):
    """
    Deep Sea Exploration Environment

    The agent starts in the top-left cell of an N x N grid and drops one row per
    step, so every episode lasts exactly N - 1 steps. At each step it picks
    left (0) or right (1); the column is clipped to the grid edges.

    Rewards:
        * every "right" action costs a small penalty of -0.01 / N;
        * "left" actions are free (reward 0.0);
        * ending the episode in the bottom-right cell -- which requires choosing
          "right" on every step -- adds a final reward of +1.0.

    Observations are the flattened cell index ``row * N + col``.
    """

    def __init__(self, N=5):
        super(DeepSeaEnv, self).__init__()
        # With N < 2 the termination test (steps_taken == N - 1) can never fire after
        # a step, so episodes would run forever and walk off the grid.
        if N < 2:
            raise ValueError(f"N must be at least 2, got {N}")
        self.N = N

        # Define action and observation spaces
        self.action_space = spaces.Discrete(2)  # 0: Left, 1: Right
        self.observation_space = spaces.Discrete(N * N)  # Flattened grid

        # Reward constants
        self.penalty = -0.01 / N  # charged on every "right" action
        self.final_reward = 1.0   # bonus for finishing in the bottom-right cell

        # Initialize state
        self.reset()

    def reset(self, seed=None, options=None):
        """Reset to the top-left cell; returns ``(obs, info)`` with an empty info dict."""
        super().reset(seed=seed)
        self.state = (0, 0)  # (row, col) of the top-left corner
        self.steps_taken = 0
        return self._get_obs(), {}

    def _get_obs(self):
        """Convert the (row, col) state to the flat index ``row * N + col``."""
        row, col = self.state
        return row * self.N + col

    def step(self, action):
        """
        Take a step in the environment.

        Parameters:
            action (int): 0 for left, 1 for right.

        Returns:
            tuple: (obs, reward, terminated, truncated, info); ``truncated`` is always
            False because episodes always end by termination after N - 1 steps.
        """
        row, col = self.state

        if action == 1:  # Right: move one column right (clipped) and pay the penalty
            new_col = min(col + 1, self.N - 1)
            reward = self.penalty
        else:  # Left: move one column left (clipped), no penalty
            new_col = max(col - 1, 0)
            reward = 0.0  # float so every step returns the same reward type

        # Every action also moves the agent down one row.
        self.state = (row + 1, new_col)
        self.steps_taken += 1

        # The episode ends once the last row is reached (after N - 1 steps).
        terminated = self.steps_taken == self.N - 1
        truncated = False

        # The bottom-right cell is only reachable if every action was "right".
        if terminated and new_col == self.N - 1:
            reward += self.final_reward

        return self._get_obs(), reward, terminated, truncated, {}

    def render(self):
        """Print the grid as an N x N array of zeros with a 1 at the agent's cell."""
        grid = np.zeros((self.N, self.N))
        row, col = self.state
        grid[row, col] = 1  # Mark the agent's position

        print(grid)

    def close(self):
        """Nothing to release."""
        pass

# Register the environment under 'DeepSea-v0' so it can be built with gym.make();
# the registered default grid size is N=7.
gym.register(
    id='DeepSea-v0',
    entry_point=DeepSeaEnv,
    kwargs=dict(N=7),
)

My Algorithm

In [None]:
N = 7  # DeepSea grid size (same value as the DeepSea-v0 registration kwargs)

# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

def normalize_slices(tensor):
    """Scale ``tensor`` so every slice along its last dimension sums to 1."""
    slice_totals = tensor.sum(dim=-1, keepdim=True)
    return tensor / slice_totals

def normalise_initial(counts):
    """Turn a tensor of counts into an empirical distribution (all entries sum to 1)."""
    grand_total = counts.sum()
    return counts / grand_total

def softmax_policy(policy_table):
    """Action probabilities from logits: softmax over the last (action) dimension."""
    return policy_table.softmax(dim=-1)

def state_to_index(state, env):
    """Map an observation to a single integer index.

    ``Discrete`` observations are returned unchanged. ``MultiDiscrete`` observations
    are flattened in row-major order using the space's ``nvec``.

    Raises:
        ValueError: for any other observation-space type.
    """
    space = env.observation_space
    if isinstance(space, gym.spaces.Discrete):
        return state
    if isinstance(space, gym.spaces.MultiDiscrete):
        dims = space.nvec
        flat = 0
        # zip with range(len(dims)) stops at the shorter of state / nvec.
        for pos, component in zip(range(len(dims)), state):
            flat += component * np.prod(dims[pos + 1:])
        return int(flat)
    raise ValueError("Unsupported observation space type")

def get_num_states(env):
    """Number of distinct observations in ``env``'s observation space.

    Supports ``Discrete`` (``n``) and ``MultiDiscrete`` (product of ``nvec``).

    Raises:
        ValueError: for any other observation-space type.
    """
    space = env.observation_space
    if isinstance(space, gym.spaces.Discrete):
        return space.n
    if isinstance(space, gym.spaces.MultiDiscrete):
        return np.prod(space.nvec)
    raise ValueError("Unsupported observation space type")

def sample_steps(env, policy, num_steps, max_steps_per_episode):
    """Roll out ``policy`` in ``env`` for ``num_steps`` transitions and tally statistics.

    Observations are used directly as indices (``policy[state]``, count tensors), so
    this assumes a ``Discrete`` observation space -- TODO confirm for other spaces.

    Parameters:
        env: Gymnasium-style environment with a ``Discrete`` action space.
        policy: tensor whose row ``policy[state]`` holds action probabilities; each row
            is sampled with ``torch.multinomial``.
        num_steps: total number of transitions to collect (across episodes).
        max_steps_per_episode: cap on the length of any single episode (must be >= 1).

    Returns:
        tuple ``(transition_counts, reward_total, reward_count, initial_state_count,
        initial_states, trajectories, steps_taken)``. ``trajectories`` is a list of
        per-episode lists of ``(state, action, reward, next_state)`` tuples.

    Raises:
        ValueError: if ``max_steps_per_episode < 1`` while steps are requested.
    """
    if num_steps > 0 and max_steps_per_episode < 1:
        # Otherwise no step is ever taken and the outer loop would never finish.
        raise ValueError("max_steps_per_episode must be at least 1")

    num_states = get_num_states(env)
    num_actions = env.action_space.n
    trajectories = []
    initial_states = []
    # Counts start at one (add-one smoothing) so every (s, a) slice has a non-zero total.
    # NOTE(review): VectorizedAccumulatedData.update sums these tensors across batches,
    # so the prior is added once per batch and its weight grows -- confirm intended.
    transition_counts = torch.ones((num_states, num_actions, num_states), dtype=torch.int32, device=device)
    reward_total = torch.zeros((num_states, num_actions), device=device)
    reward_count = torch.zeros((num_states, num_actions), device=device)
    initial_state_count = torch.zeros(num_states, dtype=torch.int32, device=device)

    steps_taken = 0
    while steps_taken < num_steps:
        state_idx, _ = env.reset()
        initial_state_count[state_idx] += 1
        initial_states.append(state_idx)
        trajectory = []

        for _ in range(max_steps_per_episode):
            action = torch.multinomial(policy[state_idx].cpu(), 1).item()
            next_state_idx, reward, terminated, truncated, _info = env.step(action)
            trajectory.append((state_idx, action, reward, next_state_idx))

            transition_counts[state_idx, action, next_state_idx] += 1
            reward_total[state_idx, action] += reward
            reward_count[state_idx, action] += 1

            steps_taken += 1
            # Gymnasium ends an episode on termination OR truncation (e.g. a TimeLimit).
            if terminated or truncated or steps_taken >= num_steps:
                break
            state_idx = next_state_idx

        trajectories.append(trajectory)

    return transition_counts, reward_total, reward_count, initial_state_count, initial_states, trajectories, steps_taken

def process_trajectories(trajectories):
    """Flatten trajectories into four aligned tensors on ``device``.

    Each trajectory is a sequence of ``(state, action, reward, next_state)`` tuples.
    Returns ``(states, actions, rewards, next_states)``; rewards are float32 and the
    other three use PyTorch's default dtype inference.
    """
    flat_steps = [transition for trajectory in trajectories for transition in trajectory]
    states = [transition[0] for transition in flat_steps]
    actions = [transition[1] for transition in flat_steps]
    rewards = [transition[2] for transition in flat_steps]
    next_states = [transition[3] for transition in flat_steps]

    return (
        torch.tensor(states, device=device),
        torch.tensor(actions, device=device),
        torch.tensor(rewards, dtype=torch.float32, device=device),
        torch.tensor(next_states, device=device),
    )

def compute_J_counting(env, policy, v, R, P, gamma=0.99):
    """Discounted return of ``policy`` in a tabular model given by ``(v, R, P)``.

    Computes ``J = v^T (I - gamma * P_pi)^{-1} R_pi`` with
    ``P_pi[s, s'] = sum_a policy[s, a] * P[s, a, s']`` and
    ``R_pi[s] = sum_a policy[s, a] * R[s, a]``.

    Parameters:
        env: unused; kept so existing call sites keep working.
        policy: (S, A) action probabilities.
        v: (S,) initial-state distribution.
        R: (S, A) expected rewards.
        P: (S, A, S) transition probabilities.
        gamma: discount factor; ``I - gamma * P_pi`` must be invertible.

    Returns:
        Tensor of shape (1, 1) holding J, differentiable w.r.t. ``policy``.
    """
    num_states = P.shape[0]
    # Contract the action axis directly instead of permuting P first.
    P_pi = torch.einsum('sa,sak->sk', policy, P)
    R_pi = torch.einsum('sa,sa->s', policy, R)
    # Build the identity on P's own device/dtype rather than relying on a global `device`.
    identity = torch.eye(num_states, device=P.device, dtype=P.dtype)
    state_values = torch.linalg.solve(identity - gamma * P_pi, R_pi.unsqueeze(1))
    return v.unsqueeze(0) @ state_values

def tabular_feature_map(total_states, total_actions, regularizer, policy, initial_states, current_states, current_actions, next_states, rewards, gamma):
    """Estimate the discounted return of ``policy`` using one-hot (state, action) features.

    Ridge regression on the sampled transitions gives a reward vector ``A`` and a
    next-feature operator ``M``. The return is then ``J = A (I - gamma * M)^{-1} W``,
    where ``W`` is the mean policy-weighted feature of the observed initial states.

    Parameters:
        total_states, total_actions: sizes of the tabular state / action spaces.
        regularizer: ridge coefficient added to ``X^T X``.
        policy: (S, A) action probabilities; J is differentiable w.r.t. it.
        initial_states: non-empty sequence of initial-state indices.
        current_states, current_actions, next_states: 1-D integer tensors, one entry
            per transition.
        rewards: 1-D tensor of per-transition rewards (cast to ``policy``'s dtype).
        gamma: discount factor.

    Returns:
        Tensor of shape (1,) holding J.

    Raises:
        ValueError: if ``initial_states`` is empty (W would be 0/0).
    """
    if len(initial_states) == 0:
        raise ValueError("initial_states must be non-empty")

    # Work on policy's device/dtype instead of a module-level `device`.
    opts = dict(device=policy.device, dtype=policy.dtype)
    sample_size = len(current_states)
    latent_dim = total_states * total_actions
    action_range = torch.arange(total_actions, device=policy.device)

    # X: one-hot feature of each sampled (state, action) pair.
    X = torch.zeros(sample_size, latent_dim, **opts)
    X.scatter_(1, (current_states * total_actions + current_actions).unsqueeze(1), 1)

    # Y: expected next feature under `policy` -- policy-weighted one-hots over (next_state, a).
    Y = torch.zeros(sample_size, latent_dim, **opts)
    next_state_indices = next_states[:, None] * total_actions + action_range
    Y[torch.arange(sample_size, device=policy.device)[:, None], next_state_indices] = policy[next_states]

    # W: average initial feature vector over the observed initial states.
    init = torch.tensor(initial_states, device=policy.device)
    W = torch.zeros(latent_dim, **opts)
    W.index_add_(0, (init[:, None] * total_actions + action_range).flatten(), policy[init].flatten())
    W /= len(initial_states)

    # Integer reward tensors cannot be mixed with float features in a matmul.
    rewards = rewards.to(**opts)

    C_lambda = X.T @ X + regularizer * torch.eye(latent_dim, **opts)
    D = X.T @ Y
    E = X.T @ rewards.unsqueeze(1)

    # Ridge solutions: A is the (1, latent) reward weight, M the (latent, latent) feature operator.
    A = torch.linalg.solve(C_lambda, E).T
    M = torch.linalg.solve(C_lambda, D).T

    return A @ torch.linalg.solve(torch.eye(latent_dim, **opts) - gamma * M, W)

class VectorizedAccumulatedData:
    """Accumulates rollout statistics across batches and keeps a bounded transition window.

    * Count/total tensors (``transition_counts``, ``reward_total``, ``reward_count``,
      ``initial_state_count``) are summed over every batch passed to ``update``.
    * Per-transition data (states, actions, rewards, next states) and the initial-state
      list live in deques with ``maxlen=max_size``, so only the most recent
      ``max_size`` entries of each are retained.
    """

    def __init__(self, max_size=int(N*10000*0.15), device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.max_size = max_size
        self.device = device
        # Count/total tensors are created from the first batch (their shapes come from it).
        self.transition_counts = None
        self.reward_total = None
        self.reward_count = None
        self.initial_state_count = None
        # Bounded windows: extending past maxlen drops the oldest entries automatically.
        self.initial_states = deque(maxlen=max_size)
        self.states = deque(maxlen=max_size)
        self.actions = deque(maxlen=max_size)
        self.rewards = deque(maxlen=max_size)
        self.next_states = deque(maxlen=max_size)
        # Number of transitions currently held (capped at max_size).
        self.total_steps = 0

    def update(self, transition_counts, reward_total, reward_count, initial_state_count, initial_states, trajectories, steps):
        """Fold one batch, in the tuple layout returned by ``sample_steps``, into the store.

        NOTE(review): counts are summed as-is, so a smoothing prior already present in
        ``transition_counts`` (``sample_steps`` starts them at one) is added once per
        batch and its weight grows with the number of batches -- confirm intended.
        """
        if self.transition_counts is None:
            # clone() so later in-place `+=` never mutates the caller's tensors
            # (.to() returns the same tensor when it is already on self.device).
            self.transition_counts = transition_counts.to(self.device).clone()
            self.reward_total = reward_total.to(self.device).clone()
            self.reward_count = reward_count.to(self.device).clone()
            self.initial_state_count = initial_state_count.to(self.device).clone()
        else:
            self.transition_counts += transition_counts.to(self.device)
            self.reward_total += reward_total.to(self.device)
            self.reward_count += reward_count.to(self.device)
            self.initial_state_count += initial_state_count.to(self.device)

        self.initial_states.extend(initial_states)

        flat_steps = [step for traj in trajectories for step in traj]
        if flat_steps:  # zip(*[]) cannot be unpacked into four sequences
            states, actions, rewards, next_states = zip(*flat_steps)
            self.states.extend(states)
            self.actions.extend(actions)
            self.rewards.extend(rewards)
            self.next_states.extend(next_states)

        # The deques already evict the oldest entries beyond max_size, so no manual
        # popping is needed. Extra pops would shrink the window below max_size and raise
        # IndexError once a deque is empty. Only the counter needs updating.
        self.total_steps = min(self.total_steps + steps, self.max_size)

    def _transition_tensors(self):
        """Materialise the transition window as ``(states, actions, rewards, next_states)`` tensors."""
        return (
            torch.tensor(list(self.states), dtype=torch.long, device=self.device),
            torch.tensor(list(self.actions), dtype=torch.long, device=self.device),
            # Explicit float32: an all-integer reward window would otherwise become int64.
            torch.tensor(list(self.rewards), dtype=torch.float32, device=self.device),
            torch.tensor(list(self.next_states), dtype=torch.long, device=self.device),
        )

    def get_data(self):
        """Return the accumulated counts, the initial-state list and the transition tensors."""
        return (
            self.transition_counts,
            self.reward_total,
            self.reward_count,
            self.initial_state_count,
            list(self.initial_states),
            *self._transition_tensors(),
        )

    def process_trajectories(self):
        """Return only ``(states, actions, rewards, next_states)`` tensors for the window."""
        return self._transition_tensors()


class CustomAlgorithm:
    """Gradient-based optimisation of a tabular softmax policy from accumulated rollouts.

    Each outer iteration of ``learn`` does three things:

    1. Samples a batch of environment steps with the current policy.
    2. Folds the batch into a ``VectorizedAccumulatedData`` store.
    3. Takes ``epochs_per_batch`` Adam steps that maximise an estimated return ``J``.

    The two estimators are:

    * ``method='counting'``: ``compute_J_counting`` on the empirical P, R and initial
      distribution v built from the accumulated counts.
    * ``method='tabular'``: ``tabular_feature_map`` on the stored transition window.

    ``learn`` returns ``self`` and ``predict`` returns ``(action, state)``. These are
    the same call shapes used for the stable-baselines3 agents in this notebook.
    """

    def __init__(self, env, method='tabular', batch_size=200, epochs_per_batch=10, lr=0.01, max_accumulated_steps=10000, max_steps_per_episode=200):
        self.env = env
        self.method = method
        self.batch_size = batch_size
        self.epochs_per_batch = epochs_per_batch
        self.lr = lr
        self.max_accumulated_steps = max_accumulated_steps
        # Per-episode step cap handed to sample_steps.
        self.max_steps_per_episode = max_steps_per_episode

        self.total_states = get_num_states(env)
        self.total_actions = env.action_space.n
        self.gamma = 0.99
        self.regularizer = 0.01

        # Equal logits for every action, so the initial softmax policy is uniform.
        self.theta = torch.nn.Parameter(torch.ones(self.total_states, self.total_actions, device=device) / self.total_actions)
        self.optimizer = optim.Adam([self.theta], lr=self.lr)

        self.accumulated_data = VectorizedAccumulatedData(max_size=self.max_accumulated_steps, device=device)

    def _policy(self):
        """Current (S, A) action probabilities: softmax of ``theta`` over actions."""
        return torch.nn.functional.softmax(self.theta, dim=1)

    def learn(self, total_timesteps):
        """Train for at most ``total_timesteps`` environment steps and return ``self``.

        Raises:
            ValueError: on the first optimisation step if ``method`` is neither
                ``'counting'`` nor ``'tabular'``.
        """
        steps_taken = 0
        while steps_taken < total_timesteps:
            # Data collection: never request more than the remaining budget, so the agent
            # sees at most `total_timesteps` steps even when that is not a multiple of
            # `batch_size`.
            batch_steps = min(self.batch_size, total_timesteps - steps_taken)
            with torch.no_grad():
                new_data = sample_steps(self.env, self._policy(), batch_steps,
                                        max_steps_per_episode=self.max_steps_per_episode)
            self.accumulated_data.update(*new_data)
            steps_taken += new_data[-1]

            (transition_counts, reward_total, reward_count, initial_state_count, initial_states,
             states, actions, rewards_sample, next_states) = self.accumulated_data.get_data()

            v = normalise_initial(initial_state_count.float())
            # Mean reward per (s, a); pairs with a zero count are divided by 1 instead of 0.
            R = torch.div(reward_total, reward_count.where(reward_count != 0, torch.tensor(1.0, device=device)))
            P = normalize_slices(transition_counts.float())

            # Policy optimisation phase
            for _ in range(self.epochs_per_batch):
                self.optimizer.zero_grad()
                policy = self._policy()

                if self.method == 'counting':
                    J = compute_J_counting(self.env, policy, v, R, P, self.gamma)
                elif self.method == 'tabular':
                    J = tabular_feature_map(self.total_states, self.total_actions, self.regularizer, policy,
                                            initial_states, states, actions, next_states, rewards_sample, self.gamma)
                else:
                    raise ValueError("method must be either 'counting' or 'tabular'")

                # Gradient ascent on J. J has a single element, so backward() needs no grad argument.
                loss = -J
                loss.backward()
                self.optimizer.step()

        return self

    def predict(self, observation, state=None, deterministic=False):
        """Return ``(action, state)`` for ``observation``.

        Uses the argmax action when ``deterministic``, otherwise samples one from the
        policy. ``state`` is passed through unchanged.
        """
        with torch.no_grad():
            policy = self._policy()
            if deterministic:
                action = policy[observation].argmax().item()
            else:
                action = torch.multinomial(policy[observation], 1).item()
        return action, state

def custom_algorithm(env, method='tabular', **kwargs):
    """Build a CustomAlgorithm with the ``algo_class(env, **kwargs)`` call shape used by the runners."""
    algorithm = CustomAlgorithm(env, method=method, **kwargs)
    return algorithm

Using device: cpu


  and should_run_async(code)


REINFORCE Algorithm

In [None]:
import numpy as np
import gymnasium as gym

class REINFORCEWrapper:
    """Tabular REINFORCE (Monte-Carlo policy gradient) with a softmax policy.

    ``theta[s]`` holds the action logits for state ``s``. States are used directly as
    row indices, so they are assumed to be integers in ``range(N * N)``.
    """

    def __init__(self, env, learning_rate=0.01, gamma=0.99):
        self.env = env
        self.lr = learning_rate
        self.gamma = gamma
        self.n_actions = env.action_space.n
        # Read N from the unwrapped env. gym.make() adds wrappers, and reading attributes
        # through them is deprecated in Gymnasium (it logs a warning).
        self.N = env.unwrapped.N
        self.n_states = self.N * self.N
        self.theta = np.zeros((self.n_states, self.n_actions))  # Policy parameters (logits)

    def softmax(self, x):
        """Numerically stable softmax over the last axis."""
        e_x = np.exp(x - np.max(x))  # Subtract max for numerical stability
        return e_x / e_x.sum(axis=-1, keepdims=True)

    def choose_action(self, state):
        """Sample an action from the softmax policy at ``state``."""
        action_probs = self.softmax(self.theta[state])
        return np.random.choice(self.n_actions, p=action_probs)

    def update_policy(self, episode):
        """One REINFORCE update from ``episode``, a list of ``(state, action, reward)`` tuples."""
        G = 0.0
        for state, action, reward in reversed(episode):
            G = self.gamma * G + reward  # discounted return from this step onward

            action_probs = self.softmax(self.theta[state])
            # Gradient of log softmax w.r.t. the logits: one_hot(action) - probs.
            grad = -action_probs
            grad[action] += 1.0

            self.theta[state] += self.lr * G * grad

    def learn(self, total_timesteps):
        """Run episodes until ``total_timesteps`` steps are taken; update after each episode."""
        steps_taken = 0
        while steps_taken < total_timesteps:
            state, _ = self.env.reset()
            episode = []
            done = False
            while not done and steps_taken < total_timesteps:
                action = self.choose_action(state)
                next_state, reward, terminated, truncated, _ = self.env.step(action)
                episode.append((state, action, reward))
                state = next_state
                steps_taken += 1
                # Gymnasium ends an episode on termination OR truncation.
                done = terminated or truncated
            self.update_policy(episode)
        return self

    def predict(self, observation, state=None, deterministic=False):
        """Return ``(action, state)``, using the greedy action when ``deterministic``."""
        if deterministic:
            action = np.argmax(self.softmax(self.theta[observation]))
        else:
            action = self.choose_action(observation)
        return action, state

def reinforce_algorithm(env, learning_rate=0.01, gamma=0.99):
    """Build a REINFORCEWrapper with the ``algo_class(env, **kwargs)`` call shape used by the runners."""
    return REINFORCEWrapper(env, learning_rate=learning_rate, gamma=gamma)

In [None]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines3 import A2C, PPO
from stable_baselines3.common.env_util import make_vec_env

def evaluate_policy(env, model, n_eval_episodes=100, deterministic=False):
    """Average undiscounted episode return of ``model`` over ``n_eval_episodes`` episodes.

    Note: this shadows ``stable_baselines3.common.evaluation.evaluate_policy`` imported
    in an earlier cell, and returns a single float rather than a (mean, std) pair.

    Parameters:
        env: Gymnasium-style env (``reset() -> (obs, info)``, 5-tuple ``step``).
        model: object with ``predict(obs, deterministic=...) -> (action, state)``.
        n_eval_episodes: number of episodes to average over (must be positive).
        deterministic: forwarded to ``model.predict``; defaults to stochastic actions.

    Raises:
        ValueError: if ``n_eval_episodes`` is not positive.
    """
    if n_eval_episodes <= 0:
        raise ValueError("n_eval_episodes must be positive")

    total_reward = 0
    for _ in range(n_eval_episodes):
        state, _ = env.reset()
        done = False
        while not done:
            action, _ = model.predict(state, deterministic=deterministic)
            state, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            # Stop on truncation as well, otherwise time-limited envs never end.
            done = terminated or truncated
    return total_reward / n_eval_episodes

def run_algorithm(algo_class, env, total_steps, eval_interval, **algo_kwargs):
    """Train and evaluate a fresh agent at each checkpoint ``0, eval_interval, ..., <= total_steps``.

    A brand-new agent is built and trained from scratch for every checkpoint, so
    checkpoints are independent. Prints ``"<step>, <algo name>, <avg reward>"`` for each
    and returns ``(steps, rewards)`` as two lists.
    """
    checkpoints = []
    rewards = []

    for checkpoint in range(0, total_steps + 1, eval_interval):
        if algo_class in (A2C, PPO):
            # stable-baselines3 constructors take the policy name before the env.
            agent = algo_class('MlpPolicy', env, **algo_kwargs)
        else:
            agent = algo_class(env, **algo_kwargs)

        agent.learn(total_timesteps=checkpoint)

        avg_reward = evaluate_policy(env, agent)
        print(f"{checkpoint}, {algo_class.__name__}, {avg_reward}")
        rewards.append(avg_reward)
        checkpoints.append(checkpoint)

    return checkpoints, rewards

def run_multiple_times(algo_class, env, total_steps, eval_interval, num_runs=3, **algo_kwargs):
    """Repeat ``run_algorithm`` ``num_runs`` times; return per-checkpoint mean and std arrays."""
    per_run_rewards = [
        run_algorithm(algo_class, env, total_steps, eval_interval, **algo_kwargs)[1]
        for _ in range(num_runs)
    ]
    return np.mean(per_run_rewards, axis=0), np.std(per_run_rewards, axis=0)

# Set up the environment. N is passed to gym.make so the grid size used in this
# experiment always matches the environment; make() kwargs override the
# registration defaults.
N = 7  # Set your desired N value
NUM_RUNS = 3  # independent runs averaged per algorithm
env = gym.make('DeepSea-v0', N=N)

# Training budget and evaluation schedule: one checkpoint every total_steps / 56 steps.
total_steps = int(N*10000)
eval_interval = int(total_steps // 56)

# Run comparisons
custom_tabular_rewards, custom_tabular_std = run_multiple_times(custom_algorithm, env, total_steps, eval_interval, num_runs=NUM_RUNS, method='tabular', batch_size=1250, epochs_per_batch=30, lr=0.01)
custom_counting_rewards, custom_counting_std = run_multiple_times(custom_algorithm, env, total_steps, eval_interval, num_runs=NUM_RUNS, method='counting', batch_size=1250, epochs_per_batch=30, lr=0.01)
reinforce_rewards, reinforce_std = run_multiple_times(reinforce_algorithm, env, total_steps, eval_interval, num_runs=NUM_RUNS, learning_rate=0.01, gamma=0.99)
a2c_rewards, a2c_std = run_multiple_times(A2C, env, total_steps, eval_interval, num_runs=NUM_RUNS)
ppo_rewards, ppo_std = run_multiple_times(PPO, env, total_steps, eval_interval, num_runs=NUM_RUNS)

# Plot mean +/- one standard deviation for every algorithm.
steps = list(range(0, total_steps + 1, eval_interval))
curves = [
    ('Tabular Algorithm', custom_tabular_rewards, custom_tabular_std),
    ('Counting Algorithm', custom_counting_rewards, custom_counting_std),
    ('REINFORCE', reinforce_rewards, reinforce_std),
    ('A2C', a2c_rewards, a2c_std),
    ('PPO', ppo_rewards, ppo_std),
]

fig, ax = plt.subplots(figsize=(12, 8))
for curve_label, curve_mean, curve_std in curves:
    (curve_line,) = ax.plot(steps, curve_mean, label=curve_label)
    # Shade the band in the same colour as its line.
    ax.fill_between(steps, curve_mean - curve_std, curve_mean + curve_std,
                    color=curve_line.get_color(), alpha=0.3)

ax.set_xlabel('Number of Steps')
ax.set_ylabel('Average Reward')
ax.set_title(f'Algorithm Comparison for Deep Sea Exploration (N={N}, Average of {NUM_RUNS} Runs)')
ax.legend()
ax.grid(True)
plt.show()

0
0, custom_algorithm, 0.005642857142857103
1250
1250, custom_algorithm, 0.0955714285714285
2500
2500, custom_algorithm, 0.2649285714285702
3750
3750, custom_algorithm, 0.5931571428571462
5000
5000, custom_algorithm, 0.7125285714285724
6250
6250, custom_algorithm, 0.8819571428571394
7500
7500, custom_algorithm, 0.8719285714285681
8750
8750, custom_algorithm, 0.9117428571428519
10000
10000, custom_algorithm, 0.8520714285714257
11250
11250, custom_algorithm, 0.8918571428571387
12500
12500, custom_algorithm, 0.9515571428571367
13750
13750, custom_algorithm, 0.9615857142857083
15000
15000, custom_algorithm, 0.9616571428571372
16250
16250, custom_algorithm, 0.9416999999999948
17500
17500, custom_algorithm, 0.971542857142851
18750
18750, custom_algorithm, 0.9615571428571372
20000
20000, custom_algorithm, 0.9914285714285649
21250
21250, custom_algorithm, 0.9614999999999939
22500
22500, custom_algorithm, 0.9914285714285649
23750
23750, custom_algorithm, 0.9616428571428512
25000
25000, custom_a

  logger.warn(


2500, reinforce_algorithm, 0.015471428571428723
3750
3750, reinforce_algorithm, 0.005542857142857091
5000
5000, reinforce_algorithm, 0.005514285714285679
6250
6250, reinforce_algorithm, -0.004142857142857156
7500
7500, reinforce_algorithm, 0.035614285714285925
8750
8750, reinforce_algorithm, 0.015742857142857284
10000
10000, reinforce_algorithm, 0.015771428571428672
11250
11250, reinforce_algorithm, 0.045542857142857066
12500
12500, reinforce_algorithm, -0.004514285714285731
13750
13750, reinforce_algorithm, 0.005185714285714265
15000
15000, reinforce_algorithm, 0.035257142857142994
16250
16250, reinforce_algorithm, 0.005657142857142812
17500
17500, reinforce_algorithm, -0.0040428571428571545
18750
18750, reinforce_algorithm, 0.025414285714285875
20000
20000, reinforce_algorithm, 0.05557142857142849
21250
21250, reinforce_algorithm, 0.05511428571428561
22500
22500, reinforce_algorithm, 0.045414285714285796
23750
23750, reinforce_algorithm, 0.03537142857142868
25000
25000, reinforce_alg