# REINFORCE with Baseline (Separate Network)

## Setup

In [None]:
import torch
from network.PolicyNetwork import PolicyNetwork
from network.StateValueNetwork import StateValueNetwork
from train import train, test
from agents_tictactoe.ReinforceBaselineAgent import ReinforceBaselineAgent
from torch.utils.tensorboard import SummaryWriter
from env.TicTacToeEnvironment import TicTacToeEnvironment

# Choose the compute backend in order of preference:
# NVIDIA CUDA first, then Apple-silicon MPS, falling back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Network hyperparameters shared by every PolicyNetwork / StateValueNetwork
# built in this notebook.
num_hidden_units = 64
num_layers = 1
dropout_probability = 0.5

# Shared environment; num_cells is the total number of board cells
# (product of the two board_size dimensions).
env = TicTacToeEnvironment()
num_cells = env.board_size[1] * env.board_size[0]

## TRAIN
Training has already been completed and the resulting models are saved. Only run the cells in this section if you want to train again.

### Train with Random Sampling

In [None]:
import os

# Run configuration. The TensorBoard run directory and checkpoint file names
# are derived from these values so the labels always match the actual run
# (they previously said "8k" while 80000 episodes were trained).
episodes = 80000
lr = 0.002
run_tag = f"{episodes // 1000}k_{lr}"

policy_net = PolicyNetwork(num_cells, num_hidden_units, num_layers, dropout_probability, num_cells).to(device)
value_net = StateValueNetwork(num_cells, num_hidden_units, num_layers, dropout_probability).to(device)

# agent_a learns with its own policy/value networks. agent_b is built without
# networks; per this section's title it is presumably a random-move opponent
# (TODO confirm in ReinforceBaselineAgent).
agent_a = ReinforceBaselineAgent(env, policy_net, value_net, lr=lr, weight_decay=0.01)
agent_b = ReinforceBaselineAgent(env)

writer = SummaryWriter(f"runs/tictactoe_{episodes // 1000}k/reinforce_with_baseline/random")
try:
    train(env, agent_a, agent_b, episodes=episodes, log_interval=1000, writer=writer)
finally:
    # Flush pending TensorBoard events and release the event file even if
    # training is interrupted.
    writer.close()
test(env, agent_a, agent_b)

model_dir = "models/tictactoe/reinforce_with_baseline"
os.makedirs(model_dir, exist_ok=True)  # torch.save does not create missing directories
torch.save(policy_net, f"{model_dir}/policy_random_{run_tag}.pth")
torch.save(value_net, f"{model_dir}/value_random_{run_tag}.pth")

### Train with Dual Agents

In [None]:
import os

# Run configuration. The TensorBoard run directory and checkpoint file names
# are derived from these values so the labels always match the actual run
# (they previously said "8k" while 80000 episodes were trained).
episodes = 80000
lr = 0.002
run_tag = f"{episodes // 1000}k_{lr}"

# Both agents learn, each with its own independent policy/value networks.
policy_net_a = PolicyNetwork(num_cells, num_hidden_units, num_layers, dropout_probability, num_cells).to(device)
value_net_a = StateValueNetwork(num_cells, num_hidden_units, num_layers, dropout_probability).to(device)

policy_net_b = PolicyNetwork(num_cells, num_hidden_units, num_layers, dropout_probability, num_cells).to(device)
value_net_b = StateValueNetwork(num_cells, num_hidden_units, num_layers, dropout_probability).to(device)

agent_a = ReinforceBaselineAgent(env, policy_net_a, value_net_a, lr=lr, weight_decay=0.01)
agent_b = ReinforceBaselineAgent(env, policy_net_b, value_net_b, lr=lr, weight_decay=0.01)

writer = SummaryWriter(f"runs/tictactoe_{episodes // 1000}k/reinforce_with_baseline/agents")
try:
    train(env, agent_a, agent_b, episodes=episodes, log_interval=1000, writer=writer)
finally:
    # Flush pending TensorBoard events and release the event file even if
    # training is interrupted.
    writer.close()
test(env, agent_a, agent_b)

# Only agent_a's networks are persisted; save policy_net_b / value_net_b as
# well if the second agent is needed later.
model_dir = "models/tictactoe/reinforce_with_baseline"
os.makedirs(model_dir, exist_ok=True)  # torch.save does not create missing directories
torch.save(policy_net_a, f"{model_dir}/policy_agents_{run_tag}.pth")
torch.save(value_net_a, f"{model_dir}/value_agents_{run_tag}.pth")

# TEST

In [None]:
def train_with_random(policy_net, value_net, draw_board: bool = False, episodes: int = 10000):
    """Evaluate trained networks against a network-less (presumably random) opponent.

    Despite the name, no training happens here. The networks are wrapped in a
    ``ReinforceBaselineAgent`` (agent_a) and played against a
    ``ReinforceBaselineAgent`` built without networks (agent_b) via ``test``
    on a fresh ``TicTacToeEnvironment``.

    Args:
        policy_net: Policy network of the evaluated agent. Switched to eval mode.
        value_net: State-value (baseline) network of the evaluated agent.
            Switched to eval mode.
        draw_board: Forwarded to ``test``.
        episodes: Number of evaluation episodes, forwarded to ``test``.
    """
    # Disable dropout during evaluation: a module saved with torch.save keeps
    # the training flag it had at save time. This is harmless if ``test``
    # already does the same (not visible from here).
    policy_net.eval()
    value_net.eval()

    env = TicTacToeEnvironment()
    agent_a = ReinforceBaselineAgent(env, policy_net, value_net)
    agent_b = ReinforceBaselineAgent(env)
    test(env, agent_a, agent_b, draw_board=draw_board, episodes=episodes)


def train_with_agents(
    policy_net_1,
    value_net_1,
    policy_net_2,
    value_net_2,
    draw_board: bool = True,
    episodes: int = 10,
):
    """Evaluate two sets of trained networks head-to-head.

    Despite the name, no training happens here. Each (policy, value) pair is
    wrapped in a ``ReinforceBaselineAgent`` and the two agents play each other
    via ``test`` on a fresh ``TicTacToeEnvironment``.

    Args:
        policy_net_1: Policy network for agent_a. Switched to eval mode.
        value_net_1: State-value network for agent_a. Switched to eval mode.
        policy_net_2: Policy network for agent_b. Switched to eval mode.
        value_net_2: State-value network for agent_b. Switched to eval mode.
        draw_board: Forwarded to ``test``.
        episodes: Number of evaluation episodes, forwarded to ``test``
            (previously hard-coded to 10, which remains the default).
    """
    # Disable dropout during evaluation: a module saved with torch.save keeps
    # the training flag it had at save time. This is harmless if ``test``
    # already does the same (not visible from here).
    for net in (policy_net_1, value_net_1, policy_net_2, value_net_2):
        net.eval()

    env = TicTacToeEnvironment()
    agent_a = ReinforceBaselineAgent(env, policy_net_1, value_net_1)
    agent_b = ReinforceBaselineAgent(env, policy_net_2, value_net_2)
    test(env, agent_a, agent_b, draw_board=draw_board, episodes=episodes)

def _load_model(path):
    """Load a whole pickled ``nn.Module`` checkpoint onto ``device``."""
    # The checkpoints were written with torch.save(module), i.e. a full pickled
    # module rather than a state_dict. That requires weights_only=False on
    # PyTorch >= 2.6, where the default became True. It unpickles arbitrary
    # objects, so only load trusted local files. map_location lets
    # CUDA-trained checkpoints load on a CPU- or MPS-only machine.
    return torch.load(path, map_location=device, weights_only=False)


model_dir = "models/tictactoe/reinforce_with_baseline"

# Networks trained against the network-less (random) opponent.
policy_net_random = _load_model(f"{model_dir}/policy_random_40k_0.001.pth")
value_net_random = _load_model(f"{model_dir}/value_random_40k_0.001.pth")

# Evaluate them against the random opponent.
train_with_random(policy_net_random, value_net_random)

# Networks of agent_a from dual-agent training. They were previously loaded
# twice in a row; once is enough.
policy_net_agents = _load_model(f"{model_dir}/policy_agents_40k_0.001.pth")
value_net_agents = _load_model(f"{model_dir}/value_agents_40k_0.001.pth")

# Evaluate the dual-agent-trained networks against the random opponent.
train_with_random(policy_net_agents, value_net_agents)

# Head-to-head: random-trained networks (agent_a) vs. dual-agent-trained
# networks (agent_b).
train_with_agents(policy_net_random, value_net_random, policy_net_agents, value_net_agents)