## Testing BARTTSAgent with a Simulation

This notebook runs a simple simulation using a scenario (LinearOffsetScenario) and a simulation function (simulate), both imported from bart_playground.bandit.sim_util. The simulation loops over a fixed number of rounds (draws). At each round, the scenario generates covariates and rewards; each agent (a BARTTSAgent or, in the cells below, a BCFAgent) selects an arm based on its current state; the cumulative regret is updated; and the agent's state is updated with the observed reward.
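The loop described above can be sketched in a few lines of NumPy. This is only an illustrative sketch, not the code in sim_util: the names ToyLinearScenario, UniformRandomAgent, choose_arm, update_state, and toy_simulate are hypothetical, and regret is assumed here to be the gap between the best expected reward and the expected reward of the chosen arm.

```python
import numpy as np


class ToyLinearScenario:
    """Hypothetical stand-in scenario: expected reward of arm k is beta[k] @ x."""

    def __init__(self, P, K, sigma2=1.0, seed=0):
        self.rng = np.random.default_rng(seed)
        self.P, self.K, self.sigma2 = P, K, sigma2
        self.beta = self.rng.normal(size=(K, P))  # one coefficient vector per arm

    def draw(self):
        x = self.rng.normal(size=self.P)          # covariates for this round
        mu = self.beta @ x                        # expected reward of each arm
        noise = self.rng.normal(scale=np.sqrt(self.sigma2), size=self.K)
        return x, mu, mu + noise                  # covariates, means, noisy rewards


class UniformRandomAgent:
    """Hypothetical baseline agent: picks an arm uniformly at random."""

    def __init__(self, n_arms, seed=0):
        self.rng = np.random.default_rng(seed)
        self.n_arms = n_arms

    def choose_arm(self, x):
        return int(self.rng.integers(self.n_arms))

    def update_state(self, arm, x, reward):
        pass  # a learning agent would refit or update its model here


def toy_simulate(scenario, agents, n_draws):
    """Return an (n_draws, n_agents) array of cumulative regrets."""
    cum_regrets = np.zeros((n_draws, len(agents)))
    running = np.zeros(len(agents))
    for t in range(n_draws):
        x, mu, rewards = scenario.draw()
        for i, agent in enumerate(agents):
            arm = agent.choose_arm(x)
            running[i] += mu.max() - mu[arm]      # assumed regret definition
            cum_regrets[t, i] = running[i]
            agent.update_state(arm, x, rewards[arm])  # agent sees only its own reward
    return cum_regrets


toy_regrets = toy_simulate(
    ToyLinearScenario(P=4, K=3), [UniformRandomAgent(n_arms=3)], n_draws=200
)
print(toy_regrets.shape)
```

With this regret definition, each agent's curve never decreases. A random agent's regret grows roughly linearly, which makes it a useful baseline when reading the plots below.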

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from bart_playground.bandit.basic_bart import BARTTSAgent
from bart_playground.bandit.bcf_agent import BCFAgent
# Import only the names used below instead of a wildcard import.
from bart_playground.bandit.sim_util import LinearOffsetScenario, simulate

### Running the Simulation
 
In this example, we set up a scenario with a given number of arms and features, create a BARTTSAgent (or BCFAgent), run the simulation for a fixed number of draws, and then plot the cumulative regret over time.

In [2]:
# Simulation parameters
n_arms = 3
n_features = 4
n_draws = 500

# Create a scenario instance (LinearOffsetScenario in this example)
scenario = LinearOffsetScenario(P=n_features, K=n_arms, sigma2=1.0)

# Create a list of agents.
# Here only one agent is created, but we can create multiple agents if needed.
agent = BCFAgent(n_arms=n_arms, n_features=n_features, nskip=100, ndpost=10)
agents = [agent]

In [3]:
# Run the simulation.
# The helper calls simulate() once and stores the result in a global,
# so the result stays available when the call is run through %prun (next cell).
def run_simulation():
    global simulation_result
    simulation_result = simulate(scenario, agents, n_draws=n_draws)
    return simulation_result

_ = run_simulation()

Simulating: 100%|██████████| 500/500 [01:06<00:00,  7.52it/s]


In [4]:
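# Optional profiling: uncomment to profile run_simulation() and render the
# call graph as a PNG (requires gprof2dot and Graphviz).
# %prun -s cumtime -D profile_bandit.prof -q run_simulation()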
# !gprof2dot -f pstats profile_bandit.prof -o profile_bandit.dot
# !dot -Tpng profile_bandit.dot -o profile_bandit.png
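Outside IPython, the same kind of cumulative-time report can be produced with the standard-library cProfile and pstats modules. The sketch below profiles a hypothetical placeholder function, busy_work, so that it runs on its own; in this notebook the profiled call would be run_simulation().

```python
import cProfile
import io
import pstats


def busy_work(n=20000):
    """Hypothetical placeholder for the call being profiled."""
    return sum(i * i for i in range(n))


profiler = cProfile.Profile()
profiler.enable()
busy_work()
profiler.disable()

# Sort by cumulative time (as %prun -s cumtime does) and keep the top entries.
buf = io.StringIO()
stats = pstats.Stats(profiler, stream=buf).sort_stats("cumulative")
stats.print_stats(5)
report = buf.getvalue()
print(report)
```

Calling profiler.dump_stats("profile_bandit.prof") would write a pstats file that gprof2dot can read.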

### Plotting the Results

We plot each agent's cumulative regret against the draw index.

In [5]:
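# Unpack the result: cumulative regrets (indexed below as [draw, agent])
# and the agents' computation times.
cum_regrets, time_agent = simulation_result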

In [None]:
plt.figure(figsize=(10, 6))
# One curve per agent; loop over indices so the `agent` variable defined earlier is not rebound.
for i in range(len(agents)):
    plt.plot(cum_regrets[:, i], label=f"Agent {i+1}")
plt.xlabel("Draw")
plt.ylabel("Cumulative Regret")
plt.title("Cumulative Regret over Time")
plt.legend()
plt.show()

print("Agent computation times (seconds):", time_agent)