In [1]:
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfigurationChannel
from mlagents_envs.base_env import ActionTuple

from debug_side_channel import DebugSideChannel
from sac_agent import SACAgent, ReplayBuffer

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import sys
import wandb
import numpy as np
import yaml
import pandas as pd
import torch.optim as optim

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
with open('config/config.yaml', 'r') as file:
    config = yaml.safe_load(file)
agent_registry = []
agent_registry.append(SACAgent(
                        observation_size=config['observation_size'],
                        action_dim=config['action_dim'], 
                        hidden_size=config['hidden_size'],
                        learning_rate=config['learning_rate']))
gamma = config['gamma']

In [3]:
if config['wandb_log']:
    wandb.init(
        project="visibility-game",
    )

In [4]:
repeat = 256
root = 'tensors'
g_states = torch.cat((torch.load(f'{root}/states.pt'), torch.tensor([6.5000, 0.5000, 1.5000]).repeat(repeat, 1)), dim=0)
g_actions = torch.cat((torch.load(f'{root}/actions.pt'), torch.tensor([0,1,0,0,0]).repeat(repeat, 1)))
g_rewards = torch.cat((torch.load(f'{root}/rewards.pt') * 10, torch.tensor(10.).repeat(repeat)))
g_next_states = torch.cat((torch.load(f'{root}/next_states.pt'), torch.tensor([7.5000, 0.5000, 0.5000]).repeat(repeat, 1)), dim=0)
g_dones = torch.cat((torch.load(f'{root}/dones.pt'), torch.tensor(1.).repeat(repeat)))

In [5]:
torch.autograd.set_detect_anomaly(True)
l = g_states.shape[0]
while True:
    agent = agent_registry[0]
    states = g_states[torch.randperm(g_states.size(0))[:10]]
    actions = g_actions[torch.randperm(g_states.size(0))[:10]]
    rewards = g_rewards[torch.randperm(g_states.size(0))[:10]]
    next_states = g_next_states[torch.randperm(g_states.size(0))[:10]] 
    dones = g_dones[torch.randperm(g_states.size(0))[:10]]

    with torch.no_grad():
        next_actions = agent.actor.forward(next_states)
        next_q1 = agent.target_critic1(next_states)
        next_q2 = agent.target_critic2(next_states)
        state_values = (
            next_actions * (torch.min(next_q1, next_q2))
        ).sum(dim=1)
        target_q = rewards + (1 - dones) * gamma * state_values 

    q1 = agent.critic1(states).gather(1, actions)
    print(f'q1 {q1}')
    q2 = agent.critic2(states).squeeze()
        
    critic1_loss = F.mse_loss(q1, target_q)
    agent.critic1_optimizer.zero_grad()
    critic1_loss.backward(retain_graph=True)
    if config['wandb_log']:
        wandb.log({"critic1_loss": critic1_loss})
    agent.critic1_optimizer.step()

    critic2_loss = F.mse_loss(q2, target_q)
    agent.critic2_optimizer.zero_grad()
    critic2_loss.backward(retain_graph=True)
    if config['wandb_log']:
        wandb.log({"critic2_loss": critic2_loss})
    agent.critic2_optimizer.step()

    tau = .005
    for target_param, param in zip(agent.target_critic1.parameters(), agent.critic1.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    for target_param, param in zip(agent.target_critic2.parameters(), agent.critic2.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    new_actions = agent.actor.select_action_tanh(states)
    min_q = torch.min(agent.critic1(states), agent.critic2(states, new_actions))
    actor_loss = torch.tensor(-1) * (min_q).mean()
    agent.actor_optimizer.zero_grad()
    actor_loss.backward(retain_graph=True)
    if config['wandb_log']:
        wandb.log({"actor_loss": actor_loss})
    agent.actor_optimizer.step()

q1 tensor([[-0.1638,  0.3103, -0.3781, -0.3616, -0.3240],
        [-0.1060,  0.2313, -0.3420, -0.3466, -0.2845],
        [-0.1358,  0.2337, -0.2837, -0.3048, -0.2507],
        [-0.1638,  0.3103, -0.3781, -0.3616, -0.3240],
        [-0.1802,  0.2449, -0.2340, -0.3019, -0.2059],
        [-0.1358,  0.2337, -0.2837, -0.3048, -0.2507],
        [-0.2459,  0.2593, -0.1794, -0.3190, -0.1615],
        [-0.1802,  0.2449, -0.2340, -0.3019, -0.2059],
        [-0.1358,  0.2337, -0.2837, -0.3048, -0.2507],
        [-0.1802,  0.2449, -0.2340, -0.3019, -0.2059]],
       grad_fn=<SqueezeBackward0>)




RuntimeError: The size of tensor a (5) must match the size of tensor b (10) at non-singleton dimension 1

In [None]:
agent_registry[0].critic1.draw_graph()

{(4.5, 0.5, 1.5): [[6.627501010894775],
  [6.520047664642334],
  [5.859687328338623],
  [5.944876670837402],
  [6.512833595275879]],
 (5.5, 0.5, 1.5): [[8.0948486328125],
  [8.356216430664062],
  [7.325829982757568],
  [7.045655727386475],
  [7.409486293792725]],
 (6.5, 0.5, 1.5): [[9.506819725036621],
  [10.145913124084473],
  [8.991620063781738],
  [8.38970947265625],
  [8.388338088989258]],
 (4.5, 0.5, 0.5): [[5.247955322265625],
  [5.565765380859375],
  [5.149465084075928],
  [5.433624744415283],
  [6.0807671546936035]],
 (5.5, 0.5, 0.5): [[6.540751934051514],
  [7.0263848304748535],
  [6.499908924102783],
  [6.384011268615723],
  [6.883880615234375]],
 (6.5, 0.5, 0.5): [[7.822787761688232],
  [8.406075477600098],
  [7.906334400177002],
  [7.533953666687012],
  [7.776287078857422]],
 (6.5, 0.5, 2.5): [[10.530986785888672],
  [10.386754035949707],
  [9.546075820922852],
  [8.974699020385742],
  [8.884608268737793]],
 (6.5, 0.5, 3.5): [[11.00023078918457],
  [10.61318302154541],
  [9

In [None]:
low_q = agent.critic1(torch.tensor([[5.5, 0.5, 0.5]]), torch.tensor([[1.]])).squeeze()
print(low_q)
high_q = agent.critic1(torch.tensor([[6.5, 0.5, 1.5]]), torch.tensor([[1.]])).squeeze()
print(high_q)

tensor(7.9241, grad_fn=<SqueezeBackward0>)
tensor(9.2092, grad_fn=<SqueezeBackward0>)
