In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import time

from IPython.display import clear_output

from torch.utils.tensorboard import SummaryWriter
import socket
from datetime import datetime
import os

from agents import Agent
from replay_buffers import *
from utils import *

import copy

In [2]:
from environment import SimulationEnvironment0
num_blackholes=1
sim = SimulationEnvironment0(num_simulations=128,
                            num_blackholes=num_blackholes, 
                            force_constant=0.002, 
                            velocity_scale=0.01,
                            goal_threshold=0.05,
                            device='cuda')
states = sim.get_state()

In [3]:
import os
#os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
#torch.autograd.set_detect_anomaly(True)# 

In [4]:
x,y = torch.meshgrid(torch.arange(1000),torch.arange(1000))
pos = torch.stack([x.flatten(), y.flatten()],1)/1000
target_pos = torch.ones_like(pos)*0.25
bh_pos = torch.ones_like(pos)*0.75

st=torch.stack([pos,target_pos,bh_pos],1).cuda()

value = ((st[:,0:1]-st[:,2:]).norm(2,-1)<sim.crash_threshold)*sim.crash_reward +  ((st[:,0:1]-st[:,1:2]).norm(2,-1)<sim.goal_threshold)*sim.goal_reward
value = value.reshape(1000,1000)

ship_positions = st[:, 0, :]
goal_position = st[:, 1, :]
blackhole_positions = st[:, 2:, :]

E=[]

for i in range(1000):
    
    actions = torch.rand((1000*1000,2), device='cuda')*2-1

    goal_distance_before = torch.norm(ship_positions - goal_position, dim=1)
    distance = blackhole_positions - ship_positions.unsqueeze(1)
    inv_distance = 1 / torch.norm(distance, dim=2)
    direction = distance / torch.norm(distance, dim=2, keepdim=True)
    forces = sim.force_constant * direction * inv_distance.unsqueeze(2)
    is_crashed = distance.norm(dim=2) < sim.crash_threshold
    is_crashed = is_crashed.any(dim=1)
    forces[torch.isnan(forces)]=0
    forces[is_crashed]=0
    actions[is_crashed]=0
    ship_velocity = sim.velocity_scale * actions + forces.sum(dim=1)
    next_ship_positions = torch.clamp(ship_positions + ship_velocity, 0, 1)

    goal_distance_after = torch.norm(next_ship_positions - goal_position, dim=1)
    rewards = goal_distance_before - goal_distance_after
    is_goal_reached = goal_distance_before < sim.goal_threshold
    is_terminal = is_crashed | is_goal_reached

    pos, remainder =  (next_ship_positions*1000).floor(), (next_ship_positions*1000).frac()

    c0 = value[pos[:,0].clip(0,999).long(),pos[:,1].clip(0,999).long()]
    c1 = value[(pos[:,0]+1).clip(0,999).long(),pos[:,1].clip(0,999).long()]
    c2 = value[pos[:,0].clip(0,999).long(),(pos[:,1]+1).clip(0,999).long()]
    c3 = value[(pos[:,0]+1).clip(0,999).long(),(pos[:,1]+1).clip(0,999).long()]

    value_target = (c0*(1-remainder[:,0]) + c1*(remainder[:,0]))*(1-remainder[:,1]) + (c2*(1-remainder[:,0]) + c3*(remainder[:,0]))*(remainder[:,1])
    value_target = value_target.reshape(1000,1000)
    
    error = ((rewards.reshape(1000,1000)+0.98*value_target)-value)
    value += 0.2*error*(is_terminal.logical_not()).reshape(1000,1000)
    E.append(error.square().mean().item())

    # if i%100==99:
    #     clear_output()
    #     plt.figure(figsize=(10,5))
    #     plt.subplot(1,2,1)
    #     plt.imshow(value.cpu())
    #     plt.subplot(1,2,2)
    #     plt.plot(E)
    #     plt.show()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [5]:
import cv2
value_resized = torch.tensor(cv2.resize(value.cpu().numpy(), (100,100))).reshape(1,100,100)

In [6]:
states.shape

torch.Size([128, 3, 2])

In [7]:
from test_agents import *

In [8]:
2**10/128

8.0

In [9]:
%tensorboard --logdir runs

UsageError: Line magic function `%tensorboard` not found.


In [35]:
from torch.distributions import Categorical

LR = [3e-4]#1e-3,5e-4,2.5e-4,1e-4


# simulation
num_simulations = 2**12
num_blackholes = 1

# agent
hidden_size = 512
simlog_res = 255
simlog_half_res = simlog_res//2
simlog_max_range = 1
actions_res = 5
levels=2
input_type='complete'


# training
training_steps = 2**20
epochs=8
gamma = 0.98
smoothing = 1e-2
eps = 0.05

# replay buffers
use_prioritized = False
size = 2**16
batch_size = 2**10


plot = False



validate_every = 2**7

bin_values = (torch.arange(simlog_res)-simlog_half_res).cuda()/simlog_half_res*simlog_max_range
bin_values = bin_values.sign()*(bin_values.abs().exp()-1)

dec_x, dec_y = torch.meshgrid(torch.arange(5)/2-1, torch.arange(5)/2-1)
dec_x, dec_y = dec_x.flatten().cuda(), dec_y.flatten().cuda()

metric_idx = torch.pow(2,torch.arange(15))-1

fig, ax = plt.subplots(figsize=(10,10))


for lr in LR:

    experiment_name='test_replay_buffers'

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    log_dir = os.path.join(
        "runs",experiment_name, current_time + "_" + socket.gethostname() 
    )

    tb_writer = SummaryWriter(log_dir)

    sim = SimulationEnvironment0(num_simulations=128,
                            num_blackholes=num_blackholes, 
                            force_constant=0.002, 
                            velocity_scale=0.01,
                            goal_threshold=0.05,
                            max_steps=250,
                            device='cuda')

    agent = Agent((num_blackholes+2), hidden_size, levels, input_type, actor=False, value_dimension=simlog_res).cuda()
    actor = Agent((num_blackholes+2), hidden_size, levels, input_type, critic=False, action_dimension=5**2).cuda()

    optim = torch.optim.AdamW(agent.parameters(), lr=lr, weight_decay=1e-3)
    optim_actor = torch.optim.AdamW(actor.parameters(), lr=1e-4, weight_decay=1e-3)
    target_agent = copy.deepcopy(agent)
    target_actor = copy.deepcopy(actor)

    replay_buffer = Replay_Buffer(state_shape=((num_blackholes+2),2), action_shape=(1,), batch_size=batch_size, size=size, device='cuda')

    old_states=None

    

    R=[]

    from tqdm import tqdm
    pbar = tqdm(range(training_steps))

    x,y = torch.meshgrid(torch.arange(100),torch.arange(100))
    pos = torch.stack([x.flatten(), y.flatten()],1)/100
    target_pos = torch.ones_like(pos)*0.25
    bh_pos = torch.ones_like(pos)*0.75

    st=torch.stack([pos,target_pos,bh_pos],1)

    E = []
    VE = []


    for i in pbar:
        t0=time.time()

        # generate experience
        states = states.reshape(states.shape[0],-1).cuda()
        _, values = agent(states)
        actions, _ = actor(states)

        # assert not torch.any(torch.isnan(values))
        # assert not torch.any(torch.isinf(values.abs()))
        # assert not torch.any(torch.isnan((values-values.max(1)[0][:,None]).exp()))
        # assert not torch.any(torch.isinf((values-values.max(1)[0][:,None]).exp()))
        if(torch.any(torch.isnan((values-values.max(1)[0][:,None]).exp()))) or torch.any(torch.isinf((values-values.max(1)[0][:,None]).exp())):
            print('nan',torch.any(torch.isnan((values-values.max(1)[0][:,None]).exp())), 'inf', torch.any(torch.isinf((values-values.max(1)[0][:,None]).exp())))
        # assert not torch.any(torch.isnan(actions))
        # assert not torch.any(torch.isinf(actions.abs()))

        
        # m = actions[:,:,0]
        # s = torch.exp(actions[:,:,1])/10
        # print(m[0],s[0])

        # sampled_action = m + torch.randn_like(m) * s
        # action_prob = torch.exp(-torch.square((sampled_action-m)/s)/2) / (s*np.sqrt(2*np.pi))
#        means = actions[:,:2]
        # lower = torch.zeros(len(actions),2,2, device=actions.device)
        # lower[:,torch.tril_indices(2,2)[0], torch.tril_indices(2,2)[1]] = actions[:,2:]
        # lower += torch.eye(2, device='cuda')
        # lower = (lower.abs()+1e-5).sqrt() * lower.sign()
        # covariance = lower @ lower.transpose(-1,-2)
#        covariance = actions[:,2:].exp().diag_embed()/2

        # assert not torch.any(torch.isnan(covariance))
        # assert not torch.any(torch.isinf(covariance.abs()))

#        dist = torch.distributions.multivariate_normal.MultivariateNormal(means, covariance)
        #dist = torch.distributions.transformed_distribution.TransformedDistribution(dist, torch.distributions.transforms.TanhTransform())
        #H = dist.entropy()

        action_probs = actions.softmax(-1)
        H = -(action_probs*(action_probs+1e-8).log()).sum(-1)
        # action_probs.register_hook(print)
        dist = Categorical(action_probs)



        sampled_action = dist.sample()
        log_prob = dist.log_prob(sampled_action)
        
        # sampled_action_tanh = torch.tanh(sampled_action)
        # log_prob = log_prob - torch.log( 1- torch.tanh(sampled_action).square() ).sum(-1)

        # log_prob = torch.log(log_prob.exp()+1e-8)
        



        #log_prob.register_hook(print)
        u, v = dec_x[sampled_action], dec_y[sampled_action]
        
        sampled_action_decoded = torch.stack([u,v],1)

        rewards, new_states, is_terminal = sim.step(sampled_action_decoded)

        with torch.inference_mode():
            _, next_values = target_agent(new_states.reshape(new_states.shape[0],-1).cuda())

            expexted_next_value = (torch.softmax(next_values,1)@bin_values[:,None])[:,0]*(is_terminal.logical_not())

            expecter_target_value = rewards + gamma*expexted_next_value


            # actions_t, _ = target_actor(states)
            # means_t = actions_t[:,:2]
            # covariance_t = actions_t[:,2:].exp().diag_embed()
            # dist_t = torch.distributions.multivariate_normal.MultivariateNormal(means_t, covariance_t)
            # log_prob_t = dist_t.log_prob(sampled_action)
            # log_prob_t = log_prob_t - torch.log( 1- torch.tanh(sampled_action_tanh).square() ).sum(-1)

        
        # assert not torch.any(torch.isnan(expecter_target_value))
        # assert not torch.any(torch.isinf(expecter_target_value.abs()))

        y = two_hot_encode(expecter_target_value, simlog_max_range, simlog_res, simlog_half_res, smoothing=smoothing)
        
        # assert not torch.any(torch.isnan(y))
        # assert not torch.any(torch.isinf(y.abs()))
        
        critic_error = torch.nn.functional.cross_entropy(values, y, reduction='none')
        
        # assert not torch.any(torch.isnan(critic_error))
        # assert not torch.any(torch.isinf(critic_error.abs()))

        expexted_value = (torch.softmax(values,1)@bin_values[:,None])[:,0]
        expected_prediction_error = (expexted_value - expecter_target_value)
        actor_error = expected_prediction_error.detach() * log_prob - H*0.001 + 0.01*action_probs.square().mean(-1)#+ log_prob * 0.1 #- H * 1e-2
        
        assert not torch.any(torch.isnan(actor_error))
        assert not torch.any(torch.isinf(actor_error.abs()))

        #r = log_prob.exp()/(log_prob_t.exp()+1e-8)
        #actor_error = torch.maximum( expected_prediction_error.detach() * r, expected_prediction_error.detach() * r.clip(1-eps,1+eps)) + log_prob * 0. + (log_prob - log_prob_t).square()*1e-5


        error =  critic_error #+ actor_error*10

        optim.zero_grad()
        error.mean().backward()
        optim.step()

        optim_actor.zero_grad()
        actor_error.mean().backward()

        for parameter in actor.parameters():
            assert not torch.any(torch.isnan(parameter.grad))
            assert not torch.any(torch.isinf(parameter.grad.abs()))

        optim_actor.step()

        if isinstance(replay_buffer, Replay_Buffer):
            replay_buffer.add_experience(states=states, actions=sampled_action, action_probs=log_prob, rewards=rewards, next_states=new_states, terminals=is_terminal)
        elif isinstance(replay_buffer, Prioritized_Replay_Buffer):
            replay_buffer.add_experience(states=states, actions=sampled_action, action_probs=log_prob, rewards=rewards, next_states=new_states, terminals=is_terminal, weights=critic_error+1e-7) # 

        E.append(critic_error.mean().item())

        tb_writer.add_scalar('TD error', expected_prediction_error.mean().item(), i)
        tb_writer.add_scalar('Actor error', actor_error.mean().item(), i)
        tb_writer.add_scalar('Critic error', critic_error.mean().item(), i)
        tb_writer.add_scalar('Reward',rewards.mean().item(), i)

        # train actor with PPO
        old_probs = log_prob.exp().detach()
        for b_idx in range(0):

            actions, _ = actor(states)
            # assert not torch.any(torch.isnan(actions))
            # assert not torch.any(torch.isinf(actions.abs()))
#            means = actions[:,:2]
#            covariance = actions[:,2:].exp().diag_embed()/2
            # assert not torch.any(torch.isnan(covariance))
            # assert not torch.any(torch.isinf(covariance.abs()))
#            dist = torch.distributions.multivariate_normal.MultivariateNormal(means, covariance)

            action_probs = actions.softmax(-1)
            H = -(action_probs*(action_probs+1e-8).log()).sum(-1)
            # action_probs.register_hook(print)
            dist = Categorical(action_probs)

            #sampled_action = dist.sample()
            log_prob = dist.log_prob(sampled_action)
            
            #sampled_action = torch.tanh(sampled_action)
#            log_prob = log_prob - torch.log( 1- torch.tanh(sampled_action).square() ).sum(-1)

            probs = log_prob.exp()

            #r = probs/old_probs.detach()
            r = log_prob.exp()/(old_probs+1e-8)

            # other_points = torch.rand(16,128,2).cuda()*4-2
            # other_logprobs = dist.log_prob(other_points)
            # other_points_tanh = torch.tanh(other_points)
            # other_logprobs = other_logprobs - torch.log( 1- torch.tanh(other_points_tanh).square() ).sum(-1)
            #H = -( other_logprobs*other_logprobs.exp() ).mean(0)/2 - torch.log(log_prob.exp()+1e-8)/2 
            #    - H*0.001

            actor_error = torch.maximum( expected_prediction_error.detach() * r, expected_prediction_error.detach() * r.clip(1-eps,1+eps)) - H*0.01 # + (log_prob - log_prob_t).square()*0.001

            optim_actor.zero_grad()
            actor_error.mean().backward()
            optim_actor.step()

            
        # if (i+1)%16==0:
        #     for e in range(epochs):
        #         optim_actor.zero_grad()
        #         for b_idx in range(len(replay_buffer)):

        #             b_states, b_actions, b_log_prob, b_rewards, b_next_states, b_terminals = replay_buffer.get(b_idx) #, weights get_high_priority_batch

        #             _, values = agent(b_states.reshape(b_states.shape[0],-1).cuda())

                    

        #             actions, _ = actor(b_states.reshape(b_states.shape[0],-1).cuda())

        #             with torch.inference_mode():
        #                 _, next_values = target_agent(b_next_states.reshape(b_next_states.shape[0],-1).cuda())

        #                 expexted_next_value = (torch.softmax(next_values,1)@bin_values[:,None])[:,0]*(b_terminals.logical_not())

        #                 expecter_target_value = b_rewards + gamma*expexted_next_value

                                    
        #             y = two_hot_encode(expecter_target_value, simlog_max_range, simlog_res, simlog_half_res, smoothing=smoothing)

        #             critic_error = (torch.nn.functional.cross_entropy(values, y, reduction='none')) #/ weights off_policy_coeff.detach()*

        #             optim.zero_grad()
        #             critic_error.mean().backward()
        #             optim.step()
                    

        #             expexted_value = (torch.softmax(values,1)@bin_values[:,None])[:,0]
        #             expected_prediction_error = (expexted_value - expecter_target_value)

        #             # assert not torch.any(torch.isnan(actions))
        #             # assert not torch.any(torch.isinf(actions.abs()))
        #             means = actions[:,:2]
        #             covariance = actions[:,2:].exp().diag_embed()/2
        #             # assert not torch.any(torch.isnan(covariance))
        #             # assert not torch.any(torch.isinf(covariance.abs()))
        #             dist = torch.distributions.multivariate_normal.MultivariateNormal(means, covariance)

        #             #sampled_action = dist.sample()
        #             log_prob = dist.log_prob(b_actions)
                    
        #             #sampled_action = torch.tanh(sampled_action)
        #             log_prob = log_prob - torch.log( 1- torch.tanh(b_actions).square() ).sum(-1)

        #             #r = probs/old_probs.detach()
        #             r = log_prob.exp()/(b_log_prob.exp()+1e-8)

        #             # other_points = torch.rand(16,128,2).cuda()*4-2
        #             # other_logprobs = dist.log_prob(other_points)
        #             # other_points_tanh = torch.tanh(other_points)
        #             # other_logprobs = other_logprobs - torch.log( 1- torch.tanh(other_points_tanh).square() ).sum(-1)
        #             #H = -( other_logprobs*other_logprobs.exp() ).mean(0)/2 - torch.log(log_prob.exp()+1e-8)/2 
        #             #    - H*0.001

        #             actor_error = torch.maximum( expected_prediction_error.detach() * r, expected_prediction_error.detach() * r.clip(1-eps,1+eps))# + (log_prob - log_prob_t).square()*0.001

        #             actor_error.mean().backward()
        #         optim_actor.step()


        update_target_model(model=agent, target_model=target_agent, decay=1e-3)
        update_target_model(model=actor, target_model=target_actor, decay=1e-4)


        # train critic on old experience
        for b_idx in range(0):
            b_states, b_actions, b_log_prob, b_rewards, b_next_states, b_terminals, weights = replay_buffer.get_batch() #, weights get_high_priority_batch

            _, values = agent(b_states.reshape(b_states.shape[0],-1).cuda())

            # assert not torch.any(torch.isnan(values))
            # assert not torch.any(torch.isinf(values.abs()))
            # assert not torch.any(torch.isnan((values-values.max(1)[0][:,None]).exp()))
            # assert not torch.any(torch.isinf((values-values.max(1)[0][:,None]).exp()))
            if(torch.any(torch.isnan((values-values.max(1)[0][:,None]).exp()))) or torch.any(torch.isinf((values-values.max(1)[0][:,None]).exp())):
                print('nan',torch.any(torch.isnan((values-values.max(1)[0][:,None]).exp())), 'inf', torch.any(torch.isinf((values-values.max(1)[0][:,None]).exp())))

            #actions, _ = actor(b_states.reshape(b_states.shape[0],-1).cuda())

            means = actions[:,:2]
            covariance = actions[:,2:].exp().diag_embed()
            dist = torch.distributions.multivariate_normal.MultivariateNormal(means, covariance)

            # lower = torch.zeros(len(actions),2,2, device=actions.device)
            # lower[:,torch.tril_indices(2,2)[0], torch.tril_indices(2,2)[1]] = actions[:,2:]
            # lower += torch.eye(2, device='cuda')
            # lower = (lower.abs()+1e-5).sqrt() * lower.sign()
            # covariance = lower @ lower.transpose(-1,-2) 
            

            # sampled_action = dist.sample()
            # sampled_action = torch.tanh(sampled_action)

            # log_prob = dist.log_prob(b_actions)
            # log_prob = log_prob - torch.log( 1- torch.tanh(b_actions).square() ).sum(-1)

            # actions = actions.reshape(actions.shape[0],2,-1)
            # actions/=2
            # action_probs = actions.softmax(-1)
            # action_prob = action_probs.gather(-1, ((b_actions+1)/2*5).long()[:,:,None]).prod(1).squeeze(1)
            # # assert not torch.any(torch.isnan(action_prob))
            
            # off_policy_coeff = torch.exp((log_prob - b_log_prob).clip(-100,5))
            # off_policy_coeff/=off_policy_coeff.sum()

            with torch.inference_mode():
                _, next_values = target_agent(b_next_states.reshape(b_next_states.shape[0],-1).cuda())

                expexted_next_value = (torch.softmax(next_values,1)@bin_values[:,None])[:,0]*(b_terminals.logical_not())

                expecter_target_value = b_rewards + gamma*expexted_next_value

            
            # assert not torch.any(torch.isnan(expecter_target_value))
            # assert not torch.any(torch.isinf(expecter_target_value.abs()))

            y = two_hot_encode(expecter_target_value, simlog_max_range, simlog_res, simlog_half_res, smoothing=smoothing)
            
            # assert not torch.any(torch.isnan(y))
            # assert not torch.any(torch.isinf(y.abs()))
            
            critic_error = (torch.nn.functional.cross_entropy(values, y, reduction='none')) #/ weights off_policy_coeff.detach()*

            # assert not torch.any(torch.isnan(critic_error))
            # assert not torch.any(torch.isinf(critic_error.abs()))

            expexted_value = (torch.softmax(values,1)@bin_values[:,None])[:,0]
            expected_prediction_error = (expexted_value - expecter_target_value)
            #actor_error = (expected_prediction_error.detach() * log_prob - log_prob * 1e-2) #- H * 1e-2 off_policy_coeff.detach()*

            if weights is not None:
                critic_error /= weights
                #actor_error /= weights

            # actor_error = off_policy_coeff.detach() * expected_prediction_error.detach() * torch.log(action_prob)

            error =  critic_error #+ actor_error*0.01

            optim.zero_grad()
            error.mean().backward()
            optim.step()

            
            # with torch.inference_mode():
            #     actions_t, _ = target_actor(b_states.reshape(b_states.shape[0],-1).cuda())
            #     means_t = actions_t[:,:2]
            #     covariance_t = actions_t[:,2:].exp().diag_embed()
            #     dist_t = torch.distributions.multivariate_normal.MultivariateNormal(means_t, covariance_t)
            #     log_prob_t = dist_t.log_prob(b_actions)
            #     log_prob_t = log_prob_t - torch.log( 1- torch.tanh(b_actions.tanh()).square() ).sum(-1)

            # actions, _ = actor(b_states.reshape(b_states.shape[0],-1).cuda())
            # means = actions[:,:2]
            # covariance = actions[:,2:].exp().diag_embed()/2
            # dist = torch.distributions.multivariate_normal.MultivariateNormal(means, covariance)
            # log_prob = dist.log_prob(b_actions)
            # log_prob = log_prob - torch.log( 1- torch.tanh(b_actions.tanh()).square() ).sum(-1)

            # probs = log_prob.exp()

            # #r = probs/old_probs.detach()
            # r = log_prob.exp()/(log_prob_t.exp()+1e-8)

            # # other_points = torch.rand(16,128,2).cuda()*4-2
            # # other_logprobs = dist.log_prob(other_points)
            # # other_points_tanh = torch.tanh(other_points)
            # # other_logprobs = other_logprobs - torch.log( 1- torch.tanh(other_points_tanh).square() ).sum(-1)
            # #H = -( other_logprobs*other_logprobs.exp() ).mean(0)/2 - torch.log(log_prob.exp()+1e-8)/2 
            # #    - H*0.001

            # actor_error = torch.maximum( expected_prediction_error.detach() * r, expected_prediction_error.detach() * r.clip(1-eps,1+eps)) + (log_prob - log_prob_t).square()*0.001

            # optim_actor.zero_grad()
            # actor_error.mean().backward()
            # optim_actor.step()

            # optim_actor.zero_grad()
            # actor_error.mean().backward()
            # optim_actor.step()

            if weights is not None:
                replay_buffer.update_weights(critic_error+1e-7)


            update_target_model(model=agent, target_model=target_agent, decay=1e-3)

        states = new_states

        if i % 8 == 0:
            pbar.set_postfix_str(#f'{pred_error.mean().item():.3g}'.ljust(simlog_max_range)+
                                #f'{actor_error.mean().item():.3g}'.ljust(simlog_max_range)+
                                f'{error.mean().item():.3g}'.ljust(10)+
                                f'{rewards.mean().item():.3g}'.ljust(10))

        R.append(rewards.mean().item())


        
        if (i+1) % validate_every == 0:

            V = []
            A = []
            for b in range((len(st)+batch_size-1)//batch_size):
                stb = st[b*batch_size:(b+1)*batch_size]
                _, v = agent(stb.reshape(stb.shape[0],-1).cuda())
                a, _ = actor(stb.reshape(stb.shape[0],-1).cuda())
                V.append(v)
                A.append(a)
            V = torch.concat(V,0)
            A = torch.concat(A,0)

            V_t = []
            for b in range((len(st)+batch_size-1)//batch_size):
                stb = st[b*batch_size:(b+1)*batch_size]
                _, v1 = target_agent(stb.reshape(stb.shape[0],-1).cuda())
                #_, v2 = agent2(st.reshape(st.shape[0],-1).cuda())
                V_t.append(v1)
            V_t = torch.concat(V_t,0)
            
            
            #plt.plot(V[0].detach().cpu())
            # V = (V.softmax(1).detach().cpu() * torch.arange(simlog_res)).sum(1) - simlog_half_res
            # V = V/simlog_half_res*simlog_max_range
            # V = (V.abs().exp()-1)*V.sign()
            V = (V.softmax(1)@bin_values[:,None])[:,0].detach().cpu()

            # V_t = (V_t.softmax(1).detach().cpu() * torch.arange(simlog_res)).sum(1) - simlog_half_res
            # V_t = V_t/simlog_half_res*simlog_max_range
            # V_t = (V_t.abs().exp()-1)*V_t.sign()
            V_t = (V_t.softmax(1)@bin_values[:,None])[:,0].detach().cpu()
            
            #plt.show()
            #print(V.mean())
            #input()

            # clear_output()

            # print(V.mean())
            # fig = plt.figure(figsize=(15,16))
            # gs = GridSpec(8, 6, figure=fig)

            # plt.subplot(gs[:2,:2])
            # plt.imshow(V.reshape(100,100),vmin=-1, vmax=1)
            tb_writer.add_image('V', (V.reshape(1,100,100)/2+0.5), i)

            # plt.subplot(gs[:2,2:4])
            # plt.imshow(V_t.reshape(100,100),vmin=-1, vmax=1)
            # plt.colorbar()
            tb_writer.add_image('V_t', V_t.reshape(1,100,100)/2+0.5, i)

            # plt.subplot(gs[:2,4:])
            # plt.imshow((torch.tensor(value_resized)-V.reshape(100,100)).square())
            # plt.colorbar()
            #tb_writer.add_image('Validation Value Error', (value_resized-V.reshape(1,100,100)).square().clip(0,1), i)

            # plt.subplot(gs[4:,1:5])

            pos = st[:,0].reshape(100,100,-1)[::2,::2]
            A = A.reshape(100,100,-1)[::2,::2]
            # A = A/2
            action_probs = A.softmax(-1)
            u = action_probs@dec_x[:,None]
            v = action_probs@dec_y[:,None]
            A = torch.concat([u,v],-1)
            
            plt.quiver(pos[...,1].flatten(), -pos[...,0].flatten(), A[...,1].tanh().detach().cpu().flatten(), -A[...,0].tanh().detach().cpu().flatten(), color='g',scale=50, headwidth=2)
            ax.axis('off')
            plt.gca().set_aspect('equal')
            plt.subplots_adjust(0,0,1,1,0,0)
            fig.canvas.draw()
            data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
            data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
            tb_writer.add_image('Policy visualization', np.transpose(data,(2,0,1)) , i)
            plt.clf()
            #plt.show()

            # plt.subplot(gs[2:4,:2])
            # plt.plot(E)
            # plt.plot(np.convolve(np.array(E), np.ones(100),'same')/np.convolve(np.ones(len(E)), np.ones(100),'same'))
            # plt.yscale('log')
            

            
            # plt.subplot(gs[2:4,2:4])
            # plt.plot(R)
            # plt.plot(np.convolve(np.array(R), np.ones(100),'same')/np.convolve(np.ones(len(R)), np.ones(100),'same'))
            
            # plt.subplot(gs[2:4,4:])
            val_err = (value_resized-V.reshape(1,100,100)).square().mean().item()
            VE.append(val_err)
            # plt.plot(VE)
            #plt.show()
            tb_writer.add_scalar('Validation Error',val_err, i)

        #tb_writer.add_scalar('Iteration time',time.time()-t0, i)
    
    indices = metric_idx[metric_idx<len(VE)]
    tb_writer.add_hparams(  {"lr":lr,
                            "epochs":epochs,
                            "gamma": gamma,
                            "smoothing":smoothing,
                             },
                            {f"error_at_{int((indices[i]+1)*validate_every)}": ve.item() for i,ve in enumerate(torch.tensor(VE)[indices])})
    

    torch.save({
    "E":torch.tensor(E),
    "VE":torch.tensor(VE),
    },
    os.path.join(log_dir,'results.pth'))



 10%|█         | 109570/1048576 [34:29<8:30:25, 30.66it/s, 0.944     0.00781   ]

In [31]:
actions[-3]

tensor([ 23.4095,  12.3554,  33.1925,  25.2574,  72.6465,   5.9458,   2.1978,
         19.1969,  33.1635,  25.4218, -12.5538,   0.2201,  10.0015,  24.0330,
         30.6597, -18.6875, -15.5374,  10.9901,   7.9407,  20.0727, -35.9163,
        -13.3782,  -9.1114,   5.9383,  14.0954], device='cuda:0',
       grad_fn=<SelectBackward0>)

In [27]:
actions[-3].softmax(-1)

tensor([4.1366e-22, 6.5453e-27, 7.3345e-18, 2.6252e-21, 1.0000e+00, 1.0772e-29,
        2.5382e-31, 6.1254e-24, 7.1248e-18, 3.0945e-21, 9.9536e-38, 3.5127e-32,
        6.2178e-28, 7.7165e-22, 5.8258e-19, 2.1584e-40, 5.0377e-39, 1.6710e-27,
        7.9184e-29, 1.4707e-23, 0.0000e+00, 4.3647e-38, 3.1117e-36, 1.0690e-29,
        3.7289e-26], device='cuda:0', grad_fn=<SoftmaxBackward0>)

In [30]:
action_probs[-3]*action_probs[-3].log()

tensor([-2.0367e-20, -3.9462e-25, -2.8938e-16, -1.2441e-19,  0.0000e+00,
        -7.1847e-28, -1.7881e-29, -3.2740e-22, -2.8131e-16, -1.4614e-19,
        -8.4805e-36, -2.5441e-30, -3.8951e-26, -3.7513e-20, -2.4461e-17,
        -1.9714e-38, -4.4424e-37, -1.0303e-25, -5.1237e-27, -7.7319e-22,
                nan, -3.7547e-36, -2.5440e-34, -7.1313e-28, -2.1833e-24],
       device='cuda:0', grad_fn=<MulBackward0>)

In [24]:
H[-3]

tensor(nan, device='cuda:0', grad_fn=<SelectBackward0>)

In [15]:
old_probs

tensor([0.0472, 0.0428, 0.8490, 0.9992, 0.0380, 0.0313, 0.0405, 0.0374, 0.9998,
        0.0341, 0.0621, 0.9937, 0.0402, 0.0434, 0.0371, 0.0443, 0.0421, 0.0293,
        0.9920, 0.0387, 0.0429, 0.0481, 0.0431, 0.9970, 0.0365, 0.0413, 0.0451,
        0.0419, 0.0427, 0.0404, 0.0395, 0.1285, 0.9669, 0.0497, 0.0360, 0.0440,
        0.0347, 0.0334, 0.0421, 0.0331, 0.5921, 0.0456, 0.7926, 0.0416, 0.9987,
        0.0368, 0.9901, 0.8634, 0.0373, 1.0000, 0.0396, 0.0457, 0.0393, 0.0379,
        0.0440, 0.0354, 0.0576, 0.0416, 0.9999, 0.0359, 0.0329, 0.0419, 0.0407,
        0.0398, 0.0473, 0.0420, 0.0397, 0.0414, 0.0485, 0.0404, 0.9208, 0.0374,
        0.0458, 0.0466, 0.0467, 0.0402, 0.0540, 0.0458, 0.0451, 0.0386, 0.2877,
        0.0401, 0.0475, 1.0000, 0.0416, 0.0476, 0.0389, 0.0445, 0.0452, 0.0480,
        0.0475, 0.0447, 0.0477, 0.0395, 0.0495, 0.0353, 0.0385, 0.0458, 0.0365,
        1.0000, 0.9640, 0.0360, 0.0357, 1.0000, 0.0343, 0.0523, 0.0341, 0.9988,
        0.0376, 0.9999, 0.0357, 0.0382, 