In [10]:
# import os
# os.listdir('../input/rl-project')
# import sys
# sys.path.insert(0,'../input/rl-project/')

In [11]:
import torch
from torch.distributions import Categorical
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import time
import math
import cv2

from IPython.display import clear_output

from torch.utils.tensorboard import SummaryWriter
import socket
from datetime import datetime
import os

from agents import Agent
from environment import SimulationEnvironment0
from replay_buffers import *
from utils import *

import copy
experiment_name='discrete-ACER-PPO'

In [12]:
torch.autograd.set_detect_anomaly(False)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x1761ea04490>

In [14]:
seed = 0

EXPERIMENTS = [#{"entropy": 3e-4, 'PPO':4, 'replay_ratio':0, 'training_steps':2**17},
               {"entropy": 3e-4, 'PPO':4, 'replay_ratio':1, 'training_steps':2**17},
               ]


# simulation
num_simulations = 128
num_blackholes = 1

# agent
hidden_size = 512
simlog_res = 255
use_symlog = True
simlog_half_res = simlog_res//2
simlog_max_range = 1
actions_res = 5
levels=2
input_type = 'complete'

lr = 3e-4
lr_actor = 3e-5


# training
#training_steps = 2**18
#epochs=8
lamb = 0.8
gamma = 0.98
smoothing = 1e-2
eps = 0.05 # for PPO update
seg_len = 2**5
h_samples = 128


# replay buffers
buffer = Replay_Buffer
#size = 2**16
num_steps = 1024
batch_size = 2**10
replay_batch_size = num_simulations

plot = False

validate_every = 2**9

bin_values = (torch.arange(simlog_res)-simlog_half_res).cuda()/simlog_half_res*simlog_max_range
bin_values = bin_values.sign()*(bin_values.abs().exp()-1)

dec_x, dec_y = torch.meshgrid(torch.arange(actions_res)/(actions_res-1)*2-1, torch.arange(actions_res)/(actions_res-1)*2-1)
dec_x, dec_y = dec_x.flatten().cuda(), dec_y.flatten().cuda()

metric_idx = torch.pow(2,torch.arange(15))-1

fig, ax = plt.subplots(figsize=(10,10))



for experiment in EXPERIMENTS:

    torch.manual_seed(seed)

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    log_dir = os.path.join(
        "runs",experiment_name, current_time + "_" + socket.gethostname() 
    )

    tb_writer = SummaryWriter(log_dir)

    sim = SimulationEnvironment0(num_simulations=num_simulations,
                            num_blackholes=num_blackholes, 
                            force_constant=0.002, 
                            velocity_scale=0.01,
                            goal_threshold=0.05,
                            max_steps=250,
                            device='cuda')

    action_dim = actions_res**2
    if use_symlog:
        actor = Agent((num_blackholes+2)*2, hidden_size, levels, input_type, critic=False, action_dimension=action_dim).cuda()
        V = Agent((num_blackholes+2)*2, hidden_size, levels, input_type, actor=False, value_dimension=simlog_res).cuda()
        Q = Agent(hidden_size + 2, hidden_size, levels, input_type='base', actor=False, value_dimension=simlog_res).cuda()
    else:
        actor = Agent((num_blackholes+2)*2, hidden_size, levels, input_type, critic=False, action_dimension=action_dim).cuda()
        V = Agent((num_blackholes+2)*2, hidden_size, levels, input_type, actor=False, value_dimension=1).cuda()
        Q = Agent(hidden_size + 2, hidden_size, levels, input_type='base', actor=False, value_dimension=1).cuda()

    optim_actor = torch.optim.AdamW(actor.parameters(), lr=lr_actor, weight_decay=1e-3)
    optim_critic = torch.optim.AdamW(list(V.parameters())+list(Q.parameters()), lr=lr, weight_decay=1e-3)
    target_actor = copy.deepcopy(actor)
    target_V = copy.deepcopy(V)
    target_Q = copy.deepcopy(Q)

    replay_buffer = Replay_Buffer_Segments(state_shape=((num_blackholes+2),2), action_shape=(1,), segment_lenght=seg_len, num_simulations=num_simulations, num_steps=num_steps, batch_size=replay_batch_size, device='cuda')



    old_states=None

    R=[]

    from tqdm import tqdm
    pbar = tqdm(range(experiment['training_steps']))

    x,y = torch.meshgrid(torch.arange(100),torch.arange(100))
    pos = torch.stack([x.flatten(), y.flatten()],1)/100
    target_pos = torch.ones_like(pos)*0.25
    bh_pos = torch.ones_like(pos)*0.75

    st=torch.stack([pos,target_pos,bh_pos],1)

    E = []
    plotV = []
    plotVT = []
    plotPolicy = []

    states = sim.get_state()

    for i in pbar:
        t0=time.time()

        # GENERATE EXPERIENCE
        with torch.inference_mode():
            
            states = states.reshape(states.shape[0],-1).cuda()
            actions, _, _ = actor(states)


            action_probs = actions.softmax(-1)
            dist = Categorical(action_probs)
            sampled_action = dist.sample()
            log_prob = dist.log_prob(sampled_action)

            u, v = dec_x[sampled_action], dec_y[sampled_action]
            sampled_action_decoded = torch.stack([u,v],1)



            rewards, new_states, terminals = sim.step(sampled_action_decoded)

            replay_buffer.add_experience(states.reshape(new_states.shape), sampled_action, log_prob.exp(), rewards, terminals)

        tb_writer.add_scalar('Reward',rewards.mean().item(), i)

        if (i+1) % seg_len == 0:

            # ON POLICY UPDATE

            seg_state, seg_actions, seg_action_probs, seg_rewards, seg_terminal = replay_buffer.get_last_segment()

            with torch.inference_mode():

                # get target values for all the states in the segment
                V_t = torch.zeros(num_simulations*seg_len, device='cuda')
                for b_idx in range((seg_len*num_simulations+batch_size-1)//batch_size):

                    b_state = seg_state.reshape((seg_len)*num_simulations,-1)[b_idx*batch_size:(b_idx+1)*batch_size]
                    
                    _, Vs, _= target_V(b_state)
                    V_t[b_idx*batch_size:(b_idx+1)*batch_size] = (torch.softmax(Vs,1)@bin_values[:,None])[:,0]
        
                V_t = V_t.reshape(num_simulations, seg_len)

                # compute GAE and TD lambda returns
                gae = torch.zeros_like(V_t)
                for t in reversed(range(seg_len-1)):

                    d_t = -V_t[:,t] + seg_rewards[:,t] + gamma*V_t[:,t+1]*seg_terminal[:,t].logical_not()

                    gae[:,t] = d_t + gamma*lamb*gae[:,t+1]*seg_terminal[:,t].logical_not()

                tdl = V_t + gae


            #with torch.autograd.set_detect_anomaly(True):

            # run PPO for n epochs
            for ppo_epoch in range(experiment['PPO']):
                for b_idx in range((seg_len*num_simulations+batch_size-1)//batch_size):

                    b_gae = gae[:,:-1].reshape((seg_len-1)*num_simulations)[b_idx*batch_size:(b_idx+1)*batch_size]
                    b_tdl = tdl[:,:-1].reshape((seg_len-1)*num_simulations)[b_idx*batch_size:(b_idx+1)*batch_size]
                    b_state = seg_state[:,:-1].reshape((seg_len-1)*num_simulations,-1)[b_idx*batch_size:(b_idx+1)*batch_size]
                    b_action = seg_actions[:,:-1].reshape((seg_len-1)*num_simulations)[b_idx*batch_size:(b_idx+1)*batch_size]
                    b_prob = seg_action_probs[:,:-1].reshape((seg_len-1)*num_simulations)[b_idx*batch_size:(b_idx+1)*batch_size]

                    actions, _, _ = actor(b_state)
                    _, Vs, hs = V(b_state)
                    _, Qs, _ = Q(torch.concat([hs, dec_x[b_action.long(), None], dec_y[b_action.long(), None]],-1))

                    action_probs = actions.softmax(-1)
                    H = -(action_probs*(action_probs+1e-8).log()).sum(-1)
                    dist = Categorical(action_probs)
                    probs = dist.log_prob(b_action).exp()

                    r = (probs + 1e-5)/(b_prob + 1e-5)
                    L = torch.minimum(b_gae*r, b_gae*r.clip(1-eps,1+eps))

                    actor_error = - L - H*experiment['entropy']#+ 1e-2*(1-variances).square().mean()

                    y = two_hot_encode(b_tdl, simlog_max_range, simlog_res, simlog_half_res, smoothing=smoothing)
                    critic_error_V = torch.nn.functional.cross_entropy(Vs, y, reduction='none')
                    critic_error_Q = torch.nn.functional.cross_entropy(Vs+Qs, y, reduction='none')
                    critic_error = (critic_error_V + critic_error_Q) / 2
                    error = actor_error + critic_error

                    assert not torch.any(torch.isnan(actor_error))
                    assert not torch.any(torch.isnan(critic_error))
                    assert not torch.any(torch.isinf(actor_error))
                    assert not torch.any(torch.isinf(critic_error))

                    optim_actor.zero_grad()
                    optim_critic.zero_grad()
                    error.mean().backward()
                    for j, param in enumerate(actor.parameters()):
                        assert not param.grad.isnan().any(), print(j, param)
                    optim_actor.step()
                    optim_critic.step()


            #assert (H>0).all()

            tb_writer.add_scalar('critic_error_V',critic_error_V.mean().item(), i)
            tb_writer.add_scalar('critic_error_Q',critic_error_Q.mean().item()-critic_error_V.mean().item(), i)
            tb_writer.add_scalar('hentropy',H.mean().item(), i)

            update_target_model(model=actor, target_model=target_actor, decay=1e-2)
            update_target_model(model=V, target_model=target_V, decay=1e-2)
            update_target_model(model=Q, target_model=target_Q, decay=1e-2)

            # REPLAY EXPERIENCES

            # replay experiences
            for replay_epoch in range(experiment['replay_ratio']):

                seg_state, seg_actions, seg_action_probs, seg_rewards, seg_terminal = replay_buffer.get_batch()

                with torch.inference_mode():

                    # get target values and action probs for all the states in the segment

                    V_t = torch.zeros(replay_batch_size*seg_len, device='cuda')
                    Q_t = torch.zeros(replay_batch_size*seg_len, device='cuda')
                    probs_t = torch.zeros(replay_batch_size*seg_len, device='cuda')

                    for b_idx in range((seg_len*replay_batch_size+batch_size-1)//batch_size):

                        b_state = seg_state.reshape((seg_len)*replay_batch_size,-1)[b_idx*batch_size:(b_idx+1)*batch_size]
                        b_action = seg_actions.reshape((seg_len)*replay_batch_size)[b_idx*batch_size:(b_idx+1)*batch_size]

                        sampled_action_decoded = b_action.tanh()

                        actions, _, _ = target_actor(b_state)
                        _, Vs, hs = target_V(b_state)
                        _, Qs, _ = target_Q(torch.concat([hs, dec_x[b_action.long(), None], dec_y[b_action.long(), None]],-1))

                        action_probs = actions.softmax(-1)
                        H = (action_probs*(1-action_probs+1e-8)).sum(-1)
                        dist = Categorical(action_probs)
                        log_prob = dist.log_prob(b_action)

                        probs_t[b_idx*batch_size:(b_idx+1)*batch_size] = log_prob.exp()
                        V_t[b_idx*batch_size:(b_idx+1)*batch_size] = (torch.softmax(Vs,1)@bin_values[:,None])[:,0]
                        Q_t[b_idx*batch_size:(b_idx+1)*batch_size] = (torch.softmax(Vs+Qs,1)@bin_values[:,None])[:,0]
            
                    V_t = V_t.reshape(replay_batch_size, seg_len)
                    Q_t = Q_t.reshape(replay_batch_size, seg_len)
                    probs_t = probs_t.reshape(replay_batch_size, seg_len)

                    #assert torch.all(probs_t>0)


                    # compute targets (as in RETRACE)

                    Q_ret = torch.zeros_like(V_t)
                    V_target = torch.zeros_like(V_t)

                    Q_ret[:,-1] = Q_t[:,-1]

                    for t in reversed(range(seg_len-1)):

                        ro = probs_t[:,t+1]/seg_action_probs[:,t+1] # ro of t+1
                        ci = lamb * torch.minimum(torch.ones(1, device='cuda'), ro)

                        Q_ret[:,t] = seg_rewards[:,t] + gamma*(ci*(Q_ret[:,t+1] - Q_t[:,t+1]) + V_t[:,t+1])
                        V_target[:, t+1] = ci*(Q_ret[:,t+1] - Q_t[:,t+1]) + V_t[:,t+1]

                    V_target[:, 0] = ci*(Q_ret[:,0] - Q_t[:,0]) + V_t[:,0]


                # run PPO for n epochs
                for ppo_epoch in range(experiment['PPO']):
                    for b_idx in range((seg_len*replay_batch_size+batch_size-1)//batch_size):

                        b_Q_ret = Q_ret[:,:-1].reshape((seg_len-1)*replay_batch_size)[b_idx*batch_size:(b_idx+1)*batch_size]
                        b_V_target = V_target[:,:-1].reshape((seg_len-1)*replay_batch_size)[b_idx*batch_size:(b_idx+1)*batch_size]
                        b_state = seg_state[:,:-1].reshape((seg_len-1)*replay_batch_size,-1)[b_idx*batch_size:(b_idx+1)*batch_size]
                        b_action = seg_actions[:,:-1].reshape((seg_len-1)*replay_batch_size)[b_idx*batch_size:(b_idx+1)*batch_size]
                        #b_action = seg_actions[:,:-1].reshape((seg_len-1)*replay_batch_size,-1)[b_idx*batch_size:(b_idx+1)*batch_size]
                        b_prob = probs_t[:,:-1].reshape((seg_len-1)*replay_batch_size)[b_idx*batch_size:(b_idx+1)*batch_size]


                        actions, _, _ = actor(b_state)
                        _, Vs, hs = V(b_state)
                        _, Qs, _ = Q(torch.concat([hs, dec_x[b_action.long(), None], dec_y[b_action.long(), None]],-1))
                       

                        action_probs = actions.softmax(-1)
                        H = (action_probs*(1-action_probs+1e-8)).sum(-1)
                        dist = Categorical(action_probs)
                        probs = dist.log_prob(b_action).exp()


                        r = (probs + 1e-8 )/(b_prob + 1e-8 )
                        adv = b_Q_ret - b_V_target
                        L = b_prob * torch.minimum(adv*r, adv*r.clip(1-eps,1+eps))

                        actor_error = - L - H*experiment['entropy'] #+ 0.01*variances.square().mean()

                        y_V = two_hot_encode(b_V_target, simlog_max_range, simlog_res, simlog_half_res, smoothing=smoothing)
                        y_Q = two_hot_encode(b_Q_ret, simlog_max_range, simlog_res, simlog_half_res, smoothing=smoothing)
                        critic_error_V = torch.nn.functional.cross_entropy(Vs, y_V, reduction='none')
                        critic_error_Q = torch.nn.functional.cross_entropy(Vs+Qs, y_Q, reduction='none')
                        critic_error = (critic_error_V + critic_error_Q) / 2
                        error = actor_error + critic_error

                        assert not torch.any(torch.isnan(actor_error))
                        assert not torch.any(torch.isnan(critic_error))
                        assert not torch.any(torch.isinf(actor_error))
                        assert not torch.any(torch.isinf(critic_error))

                        optim_actor.zero_grad()
                        optim_critic.zero_grad()
                        error.mean().backward()
                        # for j, param in enumerate(actor.parameters()):
                            # assert not param.grad.isnan().any(), print(j, param)
                        #optim_actor.step()
                        optim_critic.step()

                update_target_model(model=actor, target_model=target_actor, decay=1e-2)
                update_target_model(model=V, target_model=target_V, decay=1e-2)
                update_target_model(model=Q, target_model=target_Q, decay=1e-2)


        states = new_states

        if (i+1) % validate_every == 0:

            Values = []
            A = []
            for b in range((len(st)+batch_size-1)//batch_size):
                stb = st[b*batch_size:(b+1)*batch_size]
                _, v, _ = V(stb.reshape(stb.shape[0],-1).cuda())
                a, _, _ = actor(stb.reshape(stb.shape[0],-1).cuda())
                Values.append(v)
                A.append(a)
            Values = torch.concat(Values,0)
            A = torch.concat(A,0)

            V_t = []
            for b in range((len(st)+batch_size-1)//batch_size):
                stb = st[b*batch_size:(b+1)*batch_size]

                _, v, _ = target_V(stb.reshape(stb.shape[0],-1).cuda())

                V_t.append(v)
            V_t = torch.concat(V_t,0)


            if use_symlog:
                Values = (Values.softmax(1)@bin_values[:,None])[:,0].detach().cpu()
                V_t = (V_t.softmax(1)@bin_values[:,None])[:,0].detach().cpu()
            else:
                V = Values.cpu()
                V_t = V_t.cpu()

            tb_writer.add_image('V', (Values.reshape(1,100,100)/2+0.5), i)
            tb_writer.add_image('V_t', V_t.reshape(1,100,100)/2+0.5, i)

            plotV.append(Values.reshape(1,100,100).detach().cpu())
            plotVT.append(V_t.reshape(1,100,100).detach().cpu())

            pos = st[:,0].reshape(100,100,-1)[::2,::2]
            A = A.reshape(100,100,-1)[::2,::2]
            
            action_probs = A.softmax(-1)
            u = action_probs@dec_x[:,None]
            v = action_probs@dec_y[:,None]
            A = torch.concat([u,v],-1)

            plt.quiver(pos[...,1].flatten(), -pos[...,0].flatten(), A[...,1].tanh().detach().cpu().flatten(), -A[...,0].tanh().detach().cpu().flatten(), color='g',scale=50, headwidth=2)
            ax.axis('off')
            plt.gca().set_aspect('equal')
            plt.subplots_adjust(0,0,1,1,0,0)
            fig.canvas.draw()
            data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
            data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
            plt.clf()

            tb_writer.add_image('Policy visualization', np.transpose(data,(2,0,1)) , i)
            plotPolicy.append(np.transpose(data,(2,0,1)))




100%|██████████| 131072/131072 [1:08:48<00:00, 31.75it/s]


<Figure size 1000x1000 with 0 Axes>

In [None]:
b_Q_ret.max()

tensor(1.7778, device='cuda:0')

In [None]:
(p[H<0]*p[H<0].log()).mean()

NameError: name 'p' is not defined

In [None]:
(p[H<0]*p[H<0].log()).mean()

tensor(0.8005, device='cuda:0', grad_fn=<MeanBackward0>)

In [None]:
p[H<0][I[H<0]>0]

tensor([ 1.0462, 12.8096,  5.1820,  1.2462,  2.0290,  1.3864], device='cuda:0',
       grad_fn=<IndexBackward0>)

In [None]:
assert p>0

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

In [None]:
I

NameError: name 'I' is not defined