In [1]:
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfigurationChannel
from mlagents_envs.base_env import ActionTuple

from debug_side_channel import DebugSideChannel
from sac_agent import SACAgent, ReplayBuffer

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import sys
import wandb
import numpy as np
import yaml
import pandas as pd
import torch.optim as optim

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
with open('config/config.yaml', 'r') as file:
    config = yaml.safe_load(file)
agent_registry = []
agent_registry.append(SACAgent(
                        observation_size=config['observation_size'],
                        action_dim=config['action_dim'], 
                        hidden_size=config['hidden_size'],
                        learning_rate=config['learning_rate']))
gamma = config['gamma']
tau = config['tau']

In [3]:
if config['wandb_log']:
    wandb.init(
        project="visibility-game",
    )

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mr-marr747[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
#repeat = 15
root = 'tensors'
#g_states = torch.cat((torch.load(f'{root}/states.pt'), torch.tensor([6.5000, 0.5000, 1.5000]).repeat(repeat, 1)), dim=0)
#g_actions = torch.cat((torch.load(f'{root}/actions.pt'), torch.tensor([0,1,0,0,0]).repeat(repeat, 1)))
#g_rewards = torch.cat((torch.load(f'{root}/rewards.pt') * 10, torch.tensor(10.).repeat(repeat)))
#g_next_states = torch.cat((torch.load(f'{root}/next_states.pt'), torch.tensor([7.5000, 0.5000, 0.5000]).repeat(repeat, 1)), dim=0)
#g_dones = torch.cat((torch.load(f'{root}/dones.pt'), torch.tensor(1.).repeat(repeat)))

In [5]:
g_states = torch.load(f'{root}/states.pt')
g_actions = torch.load(f'{root}/actions.pt')
g_rewards = torch.load(f'{root}/rewards.pt')
g_next_states = torch.load(f'{root}/next_states.pt')
g_dones = torch.load(f'{root}/dones.pt')

In [6]:
g_rewards[g_rewards == 1000]

tensor([1000, 1000, 1000, 1000])

In [7]:
reward_idx = [g_rewards == 1000]
no_rewards = g_actions[reward_idx].shape[0]
print(g_states[reward_idx])
print(g_actions[reward_idx])
print(g_rewards[reward_idx])
print(g_next_states[reward_idx])
print(g_dones[reward_idx])


tensor([[6.5000, 0.5000, 1.5000],
        [6.5000, 0.5000, 1.5000],
        [6.5000, 0.5000, 1.5000],
        [6.5000, 0.5000, 1.5000]])
tensor([[0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0]])
tensor([1000, 1000, 1000, 1000])
tensor([[6.5000, 0.5000, 2.5000],
        [6.5000, 0.5000, 2.5000],
        [6.5000, 0.5000, 2.5000],
        [6.5000, 0.5000, 2.5000]])
tensor([1., 1., 1., 1.])


In [8]:
torch.randperm(g_states.size(0))[:10-no_rewards]

tensor([ 710, 1001,  625,  396,  122,  826])

In [9]:
l = g_states.shape[0]
total_target = torch.tensor([0])
total_q1 = torch.tensor([0])
total_q2 = torch.tensor([0])
total_actor_loss = torch.tensor([0]) 
for epoch in range(100000000):
    agent = agent_registry[0]
    sample_ind = torch.randperm(g_states.size(0))
    states = g_states[sample_ind[:10]]
    actions = g_actions[sample_ind[:10]]
    rewards = g_rewards[sample_ind[:10]] / 1000
    #if no_rewards < 10:
        #states = torch.cat( (g_states[reward_idx][:10], g_states[torch.randperm(g_states.size(0))[:10-no_rewards]]), dim=0)
        #actions = torch.cat( (g_actions[reward_idx][:10], g_actions[torch.randperm(g_actions.size(0))[:10-no_rewards]]), dim=0)
        #rewards = torch.cat( (g_rewards[reward_idx][:10], g_rewards[torch.randperm(g_rewards.size(0))[:10-no_rewards]]), dim=0)
        #actions = g_actions[reward_idx][:10] [torch.randperm(g_states.size(0))[:10]]
        #rewards = g_rewards[reward_idx][:10] [torch.randperm(g_states.size(0))[:10]]
    next_states = g_next_states[sample_ind[:10]] 
    dones = g_dones[sample_ind[:10]]

    with torch.no_grad():
        next_actions = agent.actor.forward(next_states)
        next_q1 = agent.target_critic1(next_states)
        next_q2 = agent.target_critic2(next_states)
        state_values = (
            next_actions * (torch.min(next_q1, next_q2))
        ).sum(dim=1)
        target_q = rewards + (1 - dones) * gamma * state_values 

    idx = actions.argmax(dim=1)
    q1 = agent.critic1(states)#.gather(1, actions)
    q1 = torch.gather(q1, dim=1, index=idx.unsqueeze(-1)).squeeze(-1) 
    q2 = agent.critic2(states)#.gather(1, actions)
    q2 = torch.gather(q2, dim=1, index=idx.unsqueeze(-1)).squeeze(-1)

    total_q2 = total_q2 + q2
    total_q1 = total_q1 + q1
    total_target = total_target + target_q 
    if (epoch + 1) % 100 == 0:
        critic1_loss = F.mse_loss(total_q1, total_target)
        agent.critic1_optimizer.zero_grad()
        critic1_loss.backward(retain_graph=True)
        if config['wandb_log']:
            wandb.log({"critic1_loss": critic1_loss})
        agent.critic1_optimizer.step()

        critic2_loss = F.mse_loss(total_q2, total_target)
        agent.critic2_optimizer.zero_grad()
        critic2_loss.backward(retain_graph=True)
        if config['wandb_log']:
            wandb.log({"critic2_loss": critic2_loss})
        agent.critic2_optimizer.step()

        total_target = torch.tensor([0])
        total_q1 = torch.tensor([0])
        total_q2 = torch.tensor([0])

        for target_param, param in zip(agent.target_critic1.parameters(), agent.critic1.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

        for target_param, param in zip(agent.target_critic2.parameters(), agent.critic2.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
        

    new_actions = agent.actor.forward(states)
    min_q = torch.min(agent.critic1(states), agent.critic2(states))
    actor_loss = torch.tensor(-1) * (new_actions * min_q).sum(dim=1).mean() 
    total_actor_loss = total_actor_loss + actor_loss
    if (epoch + 1) % 100 == 0:
        agent.actor_optimizer.zero_grad()
        actor_loss.backward(retain_graph=True)
        if config['wandb_log']:
            wandb.log({"actor_loss": actor_loss})
        agent.actor_optimizer.step()
        actor_loss = torch.tensor([0])

KeyboardInterrupt: 

In [None]:
agent = agent_registry[0]
positions = torch.tensor([ 
                            #[5.5, .5, 1.5],
                            [6.5, .5, 1.5]
                            ])
print(agent.critic1.forward(positions))
print(agent.critic2.forward(positions))
print(agent.actor.forward(positions))
# 1, 3

tensor([[ 0.0041, -0.0168, -0.0362,  0.0170,  0.0262]],
       grad_fn=<AddmmBackward>)
tensor([[ 0.0057, -0.0110, -0.0119,  0.0455,  0.0137]],
       grad_fn=<AddmmBackward>)
tensor([[1.4343e-03, 6.6882e-04, 9.8924e-01, 8.3080e-03, 3.4550e-04]],
       grad_fn=<SoftmaxBackward>)


In [None]:
low_q = agent.critic1(torch.tensor([[5.5, 0.5, 0.5]])).squeeze()
print(low_q)
print(agent.actor.forward(torch.tensor([[5.5, 0.5, 0.5]])))
high_q = agent.critic1(torch.tensor([[6.5, 0.5, 1.5]])).squeeze()
print(high_q)
print(agent.actor.forward(torch.tensor([[6.5, 0.5, 1.5]])))

tensor([ 0.0551, -0.1621, -0.0892, -0.1390, -0.0678],
       grad_fn=<SqueezeBackward0>)
tensor([[0.1683, 0.1641, 0.2377, 0.2750, 0.1549]], grad_fn=<SoftmaxBackward>)
tensor([ 0.1485, -0.0750, -0.1528, -0.0568, -0.0177],
       grad_fn=<SqueezeBackward0>)
tensor([[0.1549, 0.1727, 0.2399, 0.2918, 0.1407]], grad_fn=<SoftmaxBackward>)
