In [23]:
# import os
# os.listdir('../input/rl-project')
# import sys
# sys.path.insert(0,'../input/rl-project/')

In [24]:
import torch
from torch.distributions import Categorical
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import time
import math
import cv2
from tqdm import tqdm

from IPython.display import clear_output

from torch.utils.tensorboard import SummaryWriter
import socket
from datetime import datetime
import os

from agents import Agent
from environment import SimulationEnvironment0
from replay_buffers import *
from utils import *

import copy
# Sub-directory of runs/ that receives this notebook's TensorBoard logs (see the log_dir built in the training cell).
experiment_name='final_1_dist'

In [26]:
# Keep autograd anomaly detection disabled; NaN/Inf problems are instead caught by the
# explicit finiteness asserts in the training cell.
torch.autograd.set_detect_anomaly(False)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x27786a7b250>

In [27]:
def compute_dist_discrete(actions):
    """Build a categorical policy distribution from unnormalized action scores.

    Args:
        actions: tensor of logits whose last dimension indexes the discrete actions.

    Returns:
        torch.distributions.Categorical over the last dimension of ``actions``.
    """
    # Hand the scores over as logits instead of softmax-ed probabilities: Categorical then
    # normalizes with log-softmax, so log_prob stays exact for very unlikely actions instead
    # of being distorted once softmax underflows to zero.
    return Categorical(logits=actions)

def compute_dist_continuous(actions):
    """Build a correlated 2-D Gaussian policy from the actor's raw outputs.

    Expected column layout of ``actions`` (batch, 5):
        [:, 0:2]  means
        [:, 2:3]  correlation, squashed by tanh and shrunk by (1 - 1e-5) so that
                  |rho| < 1 and the covariance stays positive definite
        [:, 3:5]  log standard deviations (exponentiated below)

    Returns:
        MultivariateNormal(mean, Sigma) with Sigma_ij = rho_ij * sigma_i * sigma_j.
    """
    mu = actions[:, :2]
    rho = actions[:, 2:-2].tanh() * (1 - 1e-5)
    # These are standard deviations (not variances): they are multiplied pairwise below.
    sigma = actions[:, -2:].exp()

    batch, dim = mu.shape
    off_diag = torch.tril_indices(dim, dim, offset=-1)
    diag = torch.arange(dim)

    # Symmetric correlation matrix with a unit diagonal.
    corr = torch.zeros(batch, dim, dim, device=mu.device)
    corr[:, off_diag[0], off_diag[1]] = rho
    corr[:, off_diag[1], off_diag[0]] = rho
    corr[:, diag, diag] = 1

    cov = corr * sigma.unsqueeze(-1) * sigma.unsqueeze(1)
    return torch.distributions.multivariate_normal.MultivariateNormal(mu, cov)

In [28]:
def decode_action_discrete(dist, sampled_action, x_table=None, y_table=None):
    """Map sampled discrete action indices to 2-D continuous actions.

    Args:
        dist: distribution the indices were sampled from (only ``log_prob`` is used).
        sampled_action: integer tensor of action indices (assumed shape (batch,)).
        x_table, y_table: 1-D lookup tensors giving the x / y component for every action
            index. When omitted, the module-level ``dec_x`` / ``dec_y`` globals are used
            (the previous, implicit behaviour) -- these must then be defined beforehand.

    Returns:
        (decoded, log_prob): ``decoded`` stacks the looked-up x / y values along dim 1;
        ``log_prob`` is ``dist.log_prob(sampled_action)``.
    """
    if x_table is None:
        x_table = dec_x
    if y_table is None:
        y_table = dec_y
    log_prob = dist.log_prob(sampled_action)
    decoded = torch.stack([x_table[sampled_action], y_table[sampled_action]], 1)
    return decoded, log_prob

def decode_action_continuous(dist, sampled_action):
    """Squash a raw Gaussian sample with tanh and return its corrected log-density.

    Change of variables for a = tanh(u):
        log p(a) = log p(u) - sum_i log(1 - tanh(u_i)^2)
    A 1e-8 floor inside the log keeps it finite when |tanh(u)| rounds to 1.

    Returns:
        (tanh(sampled_action), log_prob), with the Jacobian term summed over the last dim.
    """
    squashed = sampled_action.tanh()
    jacobian_term = torch.log(1 - squashed.square() + 1e-8).sum(-1)
    return squashed, dist.log_prob(sampled_action) - jacobian_term

In [29]:
def update_models(decay=1e-3):
    """Move every target network towards its online counterpart.

    Uses the notebook-level globals actor/V/Q and target_actor/target_V/target_Q, and
    ``update_target_model`` (star-imported from utils -- presumably a soft/EMA update with
    rate ``decay``; TODO confirm). Pairs are updated in the order actor, V, Q.
    """
    for online, target in ((actor, target_actor), (V, target_V), (Q, target_Q)):
        update_target_model(model=online, target_model=target, decay=decay)

def flatten_sequences(X, removelast=False):
    """Merge the first two dimensions of every tensor in ``X`` (mutates and returns ``X``).

    2-D tensors become 1-D; tensors with more dimensions become 2-D with all trailing
    dimensions folded into the second one. With ``removelast`` the last entry of
    dimension 1 is dropped first.
    """
    for idx, seq in enumerate(X):
        if removelast:
            seq = seq[:, :-1]
        lead = seq.shape[0] * seq.shape[1]
        X[idx] = seq.reshape(lead) if seq.dim() == 2 else seq.reshape(lead, -1)
    return X

def reshape_sequences(X, shape):
    """Reshape every tensor in list ``X`` to ``shape``; mutates and returns the same list."""
    for idx, tensor in enumerate(X):
        X[idx] = tensor.reshape(shape)
    return X

def get_batch(X, batch_size, b_idx):
    """Replace each entry of ``X`` by its ``b_idx``-th slice of ``batch_size`` rows.

    Mutates and returns ``X``; the final slice may be shorter than ``batch_size``.
    """
    window = slice(b_idx * batch_size, (b_idx + 1) * batch_size)
    for idx, tensor in enumerate(X):
        X[idx] = tensor[window]
    return X

def initialize_zeros(shape, n, device):
    """Return a list of ``n`` independent zero tensors of ``shape`` allocated on ``device``."""
    return [torch.zeros(shape, device=device) for _ in range(n)]

In [33]:
# Global seed; re-applied with torch.manual_seed at the start of every experiment below.
seed = 0

# One full training run per entry. Keys read by the training loop:
#   entropy        -- weight of the entropy surrogate H in the actor loss
#   PPO            -- number of PPO epochs over each batch of data
#   replay_ratio   -- number of replay batches trained on after each on-policy segment
#   training_steps -- number of environment steps (loop iterations)
EXPERIMENTS = [#{"entropy": 3e-4, 'PPO':4, 'replay_ratio':0, 'training_steps':2**17},
               {"entropy": 3e-4, 'PPO':2, 'replay_ratio':1, 'training_steps':2**17},
               {"entropy": 3e-4, 'PPO':2, 'replay_ratio':3, 'training_steps':2**17}
               ]


# simulation
num_simulations = 128    # number of simulations run side by side in SimulationEnvironment0
num_blackholes = 1

# agent
hidden_size = 512
simlog_res = 255         # number of value bins the critics predict over (symlog mode)
use_symlog = True        # binned (distributional) critics if True, scalar critics otherwise
discrete_actions = False # False: correlated 2-D Gaussian squashed by tanh; True: actions_res x actions_res grid
simlog_half_res = simlog_res//2
simlog_max_range = 1     # bins are evenly spaced on [-simlog_max_range, simlog_max_range] before the symexp below
actions_res = 5          # per-axis resolution of the discrete action grid (discrete mode only)
levels=2                 # passed to Agent -- presumably network depth; TODO confirm in agents.py
input_type = 'complete'  # passed to Agent as its input mode; semantics defined in agents.py

lr = 3e-4                # critic (V and Q) learning rate
lr_actor = 3e-5          # actor learning rate



# training
#training_steps = 2**18
#epochs=8
lamb = 0.8               # lambda for GAE (on-policy) and for the Retrace trace coefficients (replay)
gamma = 0.98             # discount factor
smoothing = 1e-2         # passed to two_hot_encode -- presumably label smoothing of the binned targets
eps = 0.05 # for PPO update: the probability ratio is clipped to [1-eps, 1+eps]
seg_len = 2**5           # environment steps per segment; an on-policy update runs every seg_len steps
h_samples = 0            # samples for the entropy surrogate; 0 uses the density of the taken action
c=10                     # cap on the replay importance weights pi/mu


# replay buffers
num_steps = 1024         # passed to Replay_Buffer_Segments -- presumably its capacity in steps
batch_size = 2**10       # minibatch size for network evaluation and PPO updates
replay_batch_size = num_simulations  # segments per replay batch

plot = False             # not read anywhere in this cell

validate_every = 2**9    # steps between validation images written to TensorBoard

# Value-bin centres: evenly spaced on [-simlog_max_range, simlog_max_range], then mapped
# through symexp(x) = sign(x) * (exp(|x|) - 1).
bin_values = (torch.arange(simlog_res)-simlog_half_res).cuda()/simlog_half_res*simlog_max_range
bin_values = bin_values.sign()*(bin_values.abs().exp()-1)

# Lookup tables used by decode_action_discrete (only needed when discrete_actions is True).
#dec_x, dec_y = torch.meshgrid(torch.arange(actions_res)/(actions_res-1)*2-1, torch.arange(actions_res)/(actions_res-1)*2-1)
#dec_x, dec_y = dec_x.flatten().cuda(), dec_y.flatten().cuda()

metric_idx = torch.pow(2,torch.arange(15))-1   # not read anywhere in this cell

# Single reusable figure for the policy quiver plot logged during validation.
fig, ax = plt.subplots(figsize=(10,10))



def _assert_finite(*tensors):
    """Fail fast if any of ``tensors`` contains NaN or +/-Inf."""
    for tensor in tensors:
        assert torch.isfinite(tensor).all()


def _trace_coefficient(target_probs, behaviour_probs):
    """Retrace trace coefficient ``lamb * min(1, pi / mu)``.

    Uses the same +1e-5 smoothing of the probability ratio as the rest of this cell.
    """
    ro = (target_probs + 1e-5)/(behaviour_probs + 1e-5)
    ci = lamb * torch.minimum(torch.ones(1, device=ro.device), ro)
    _assert_finite(ci, ro)
    return ci


for experiment in EXPERIMENTS:

    # set seed (every experiment starts from the same seed so the configs are comparable)
    torch.manual_seed(seed)

    # initialize logging (with tensorboard): runs/<experiment_name>/<timestamp>_<hostname>
    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    log_dir = os.path.join(
        "runs", experiment_name, current_time + "_" + socket.gethostname()
    )
    tb_writer = SummaryWriter(log_dir)

    # initialize the simulation
    sim = SimulationEnvironment0(num_simulations=num_simulations,
                                 num_blackholes=num_blackholes,
                                 force_constant=0.002,
                                 velocity_scale=0.01,
                                 goal_threshold=0.05,
                                 max_steps=250,
                                 device='cuda')
    states = sim.get_state()

    # initialize the networks
    if discrete_actions:
        action_dim = actions_res**2
        compute_dist = compute_dist_discrete
        decode_action = decode_action_discrete
    else:
        # 2 means + 1 correlation + 2 log standard deviations (see compute_dist_continuous)
        action_dim = 5
        compute_dist = compute_dist_continuous
        decode_action = decode_action_continuous

    # With symlog the critics output logits over `simlog_res` value bins, otherwise a scalar.
    value_dim = simlog_res if use_symlog else 1
    state_dim = (num_blackholes+2)*2

    actor = Agent(state_dim, hidden_size, levels, input_type, critic=False, action_dimension=action_dim).cuda()
    # NOTE(review): actor2 is optimized together with actor but never evaluated in this cell.
    actor2 = Agent(state_dim, hidden_size, levels, input_type, critic=False, action_dimension=action_dim).cuda()
    V = Agent(state_dim, hidden_size, levels, input_type, actor=False, value_dimension=value_dim).cuda()
    # Q reads V's hidden features concatenated with the 2-D squashed action.
    Q = Agent(hidden_size + 2, hidden_size, levels, input_type='base', actor=False, value_dimension=value_dim).cuda()

    optim_actor = torch.optim.AdamW(list(actor.parameters())+list(actor2.parameters()), lr=lr_actor, weight_decay=1e-3)
    optim_critic = torch.optim.AdamW(list(V.parameters())+list(Q.parameters()), lr=lr, weight_decay=1e-3)
    target_actor = copy.deepcopy(actor)
    target_V = copy.deepcopy(V)
    target_Q = copy.deepcopy(Q)

    # initialize the replay buffer
    replay_buffer = Replay_Buffer_Segments(state_shape=((num_blackholes+2),2), action_shape=(2,), params_shape=(6,), segment_lenght=seg_len, num_simulations=num_simulations, num_steps=num_steps, batch_size=replay_batch_size, device='cuda')

    # initialize validation plane: 100x100 grid of positions with a fixed target and black hole
    grid_x, grid_y = torch.meshgrid(torch.arange(100), torch.arange(100), indexing='ij')
    pos = torch.stack([grid_x.flatten(), grid_y.flatten()],1)/100

    target_pos = torch.ones_like(pos)*0.25      # position of the target
    bh_pos = torch.ones_like(pos)*0.75          # position of the blackholes

    st = torch.stack([pos,target_pos,bh_pos],1)

    # Minibatch counts. The last step of every segment only serves as a bootstrap target,
    # so the PPO passes train on (seg_len-1) steps per segment.
    n_value_batches = (seg_len*num_simulations + batch_size - 1)//batch_size
    n_onpolicy = (seg_len-1)*num_simulations
    n_onpolicy_batches = (n_onpolicy + batch_size - 1)//batch_size
    n_replay_rows = seg_len*replay_batch_size
    n_replay_target_batches = (n_replay_rows + batch_size - 1)//batch_size
    n_replay = (seg_len-1)*replay_batch_size
    n_replay_batches = (n_replay + batch_size - 1)//batch_size
    n_grid_batches = (len(st) + batch_size - 1)//batch_size

    pbar = tqdm(range(experiment['training_steps']))
    for i in pbar:

        # GENERATE EXPERIENCE
        with torch.inference_mode():

            states = states.reshape(states.shape[0],-1).cuda()

            actions, _, _ = actor(states)
            dist = compute_dist(actions)

            sampled_action = dist.sample()
            # apply the tanh squashing and correct the log-density for it
            sampled_action_decoded, log_prob = decode_action(dist, sampled_action)

            rewards, new_states, terminals = sim.step(sampled_action_decoded)

            # Store the raw (pre-tanh) action with the parameters of the policy that produced it:
            # flattened 2x2 covariance followed by the mean, the layout the replay section unpacks.
            # (This used to read `cov_matrix`/`means`, which are locals of compute_dist_continuous
            # and therefore undefined -- or stale leftovers -- at this scope.)
            behaviour_params = torch.concat([dist.covariance_matrix.flatten(1), dist.mean], -1)
            replay_buffer.add_experience(states.reshape(new_states.shape), sampled_action, behaviour_params, log_prob.exp(), rewards, terminals)

        tb_writer.add_scalar('Reward', rewards.mean().item(), i)

        if (i+1) % seg_len == 0:

            # ON POLICY UPDATE

            seg_state, seg_actions, _, seg_action_probs, seg_rewards, seg_terminal = replay_buffer.get_last_segment()

            with torch.inference_mode():

                # get target values for all the states in the segment
                V_t = torch.zeros(num_simulations*seg_len, device='cuda')
                seg_state_all = seg_state.reshape(seg_len*num_simulations, -1)
                for b_idx in range(n_value_batches):
                    rows = slice(b_idx*batch_size, (b_idx+1)*batch_size)
                    _, Vs, _ = target_V(seg_state_all[rows])
                    V_t[rows] = (torch.softmax(Vs,1)@bin_values[:,None])[:,0]

                V_t = V_t.reshape(num_simulations, seg_len)

                # compute GAE and TD lambda returns; a terminal at t stops bootstrapping from t+1
                gae = torch.zeros_like(V_t)
                for t in reversed(range(seg_len-1)):
                    not_done = seg_terminal[:,t].logical_not()
                    d_t = -V_t[:,t] + seg_rewards[:,t] + gamma*V_t[:,t+1]*not_done
                    gae[:,t] = d_t + gamma*lamb*gae[:,t+1]*not_done

                tdl = V_t + gae

            # flatten the trainable (non-final) steps once instead of once per minibatch
            gae_flat = gae[:,:-1].reshape(n_onpolicy)
            tdl_flat = tdl[:,:-1].reshape(n_onpolicy)
            state_flat = seg_state[:,:-1].reshape(n_onpolicy,-1)
            action_flat = seg_actions[:,:-1].reshape(n_onpolicy,-1)
            prob_flat = seg_action_probs[:,:-1].reshape(n_onpolicy)

            # run PPO for n epochs
            for _ in range(experiment['PPO']):
                for b_idx in range(n_onpolicy_batches):
                    rows = slice(b_idx*batch_size, (b_idx+1)*batch_size)
                    b_gae = gae_flat[rows]
                    b_tdl = tdl_flat[rows]
                    b_state = state_flat[rows]
                    b_action = action_flat[rows]
                    b_prob = prob_flat[rows]

                    actions, _, _ = actor(b_state)
                    dist = compute_dist(actions)

                    _, Vs, hs = V(b_state)
                    sampled_action_decoded, log_prob = decode_action(dist, b_action)
                    _, Qs, _ = Q(torch.concat([hs, sampled_action_decoded],-1))

                    _assert_finite(actions, Vs, Qs)

                    probs = log_prob.exp()

                    # clipped PPO surrogate
                    r = (probs + 1e-5)/(b_prob + 1e-5)
                    L = torch.minimum(b_gae*r, b_gae*r.clip(1-eps,1+eps))

                    # entropy surrogate: negative density at sampled actions or at the taken action
                    if h_samples>0:
                        s = dist.sample((h_samples,))
                        _, s_log_prob = decode_action(dist, s)
                        H = 0.5 - s_log_prob.exp().mean(0)
                    else:
                        H = -probs

                    actor_error = - L - H*experiment['entropy']

                    y_tdl = two_hot_encode(b_tdl, simlog_max_range, simlog_res, simlog_half_res, smoothing=smoothing)
                    critic_error_V = torch.nn.functional.cross_entropy(Vs, y_tdl, reduction='none')
                    critic_error_Q = torch.nn.functional.cross_entropy(Vs+Qs, y_tdl, reduction='none')
                    critic_error = (critic_error_V + critic_error_Q) / 2

                    _assert_finite(actor_error, critic_error)

                    optim_critic.zero_grad()
                    critic_error.mean().backward(retain_graph=True)
                    optim_actor.zero_grad() # also drops any critic-loss gradient that reached the actor
                    actor_error.mean().backward()
                    optim_critic.step()
                    optim_actor.step()

                update_models(decay=5e-3)
            tb_writer.add_scalar('critic_error_Q', critic_error_Q.mean().item()-critic_error_V.mean().item(), i)
            tb_writer.add_scalar('hentropy', H.mean().item(), i)

            # REPLAY EXPERIENCES
            for _ in range(experiment['replay_ratio']):

                seg_state, seg_actions, seg_action_params, seg_action_probs, seg_rewards, seg_terminal = replay_buffer.get_batch()

                with torch.inference_mode():

                    # get target values and action probs for all the states in the segment
                    V_t = torch.zeros(n_replay_rows, device='cuda')
                    Q_t = torch.zeros(n_replay_rows, device='cuda')
                    probs_t = torch.zeros(n_replay_rows, device='cuda')
                    Q_t_corr = torch.zeros(n_replay_rows, device='cuda')
                    ro_corr = torch.zeros(n_replay_rows, device='cuda')

                    seg_state_all = seg_state.reshape(n_replay_rows,-1)
                    seg_actions_all = seg_actions.reshape(n_replay_rows,-1)
                    seg_params_all = seg_action_params.reshape(n_replay_rows,-1)

                    for b_idx in range(n_replay_target_batches):
                        rows = slice(b_idx*batch_size, (b_idx+1)*batch_size)
                        b_state = seg_state_all[rows]
                        b_action = seg_actions_all[rows]
                        b_seg_action_params = seg_params_all[rows]

                        actions, _, _ = target_actor(b_state)
                        dist = compute_dist(actions)

                        # stored action scored under the current target policy
                        sampled_action_decoded, log_prob = decode_action(dist, b_action)

                        # fresh action from the target policy, scored under both the target policy
                        # and the stored behaviour policy (for the bias-correction term)
                        sampled_action_corr = dist.sample()
                        sampled_action_decoded_corr, log_prob_corr = decode_action(dist, sampled_action_corr)

                        seg_cov = b_seg_action_params[:,:-2].reshape(-1,2,2)
                        seg_means = b_seg_action_params[:,-2:]
                        seg_dist = torch.distributions.multivariate_normal.MultivariateNormal(seg_means, seg_cov)
                        _, seg_log_prob_corr = decode_action(seg_dist, sampled_action_corr)

                        _, Vs, hs = target_V(b_state)
                        _, Qs, _ = target_Q(torch.concat([hs, sampled_action_decoded],-1))
                        _, Qs_corr, _ = target_Q(torch.concat([hs, sampled_action_decoded_corr],-1))

                        _assert_finite(actions, Vs, Qs)

                        V_t[rows] = (torch.softmax(Vs,1)@bin_values[:,None])[:,0]
                        Q_t[rows] = (torch.softmax(Vs+Qs,1)@bin_values[:,None])[:,0]
                        probs_t[rows] = log_prob.exp()
                        Q_t_corr[rows] = (torch.softmax(Vs+Qs_corr,1)@bin_values[:,None])[:,0]
                        ro_corr[rows] = (log_prob_corr.exp() + 1e-5)/(seg_log_prob_corr.exp() + 1e-5)

                    V_t = V_t.reshape(replay_batch_size, seg_len)
                    Q_t = Q_t.reshape(replay_batch_size, seg_len)
                    probs_t = probs_t.reshape(replay_batch_size, seg_len)
                    Q_t_corr = Q_t_corr.reshape(replay_batch_size, seg_len)
                    ro_corr = ro_corr.reshape(replay_batch_size, seg_len)

                    # compute targets (as in RETRACE)

                    Q_ret = torch.zeros_like(V_t)
                    V_target = torch.zeros_like(V_t)
                    # truncation-correction term; currently disabled in the advantage below
                    corr = (1-c/ro_corr).relu()*(Q_t_corr - V_t)

                    Q_ret[:,-1] = Q_t[:,-1]

                    for t in reversed(range(seg_len-1)):
                        ci = _trace_coefficient(probs_t[:,t+1], seg_action_probs[:,t+1]) # ro of t+1
                        # as in the GAE above, a terminal at t stops bootstrapping from t+1
                        not_done = seg_terminal[:,t].logical_not()
                        Q_ret[:,t] = seg_rewards[:,t] + gamma*not_done*(ci*(Q_ret[:,t+1] - Q_t[:,t+1]) + V_t[:,t+1])
                        V_target[:, t+1] = ci*(Q_ret[:,t+1] - Q_t[:,t+1]) + V_t[:,t+1]

                    # step 0 needs its own coefficient (it used to reuse the one computed for step 1)
                    ci = _trace_coefficient(probs_t[:,0], seg_action_probs[:,0])
                    V_target[:, 0] = ci*(Q_ret[:,0] - Q_t[:,0]) + V_t[:,0]

                    _assert_finite(V_target, Q_ret)

                # flatten the trainable (non-final) steps once instead of once per minibatch
                Q_ret_flat = Q_ret[:,:-1].reshape(n_replay)
                V_target_flat = V_target[:,:-1].reshape(n_replay)
                state_flat = seg_state[:,:-1].reshape(n_replay,-1)
                action_flat = seg_actions[:,:-1].reshape(n_replay,-1)
                prob_flat = probs_t[:,:-1].reshape(n_replay)
                prob_seg_flat = seg_action_probs[:,:-1].reshape(n_replay)
                corr_flat = corr[:,:-1].reshape(n_replay)

                # run PPO for n epochs
                for _ in range(experiment['PPO']):
                    for b_idx in range(n_replay_batches):
                        rows = slice(b_idx*batch_size, (b_idx+1)*batch_size)
                        b_Q_ret = Q_ret_flat[rows]
                        b_V_target = V_target_flat[rows]
                        b_state = state_flat[rows]
                        b_action = action_flat[rows]
                        b_prob = prob_flat[rows]
                        b_prob_seg = prob_seg_flat[rows]
                        b_corr = corr_flat[rows]

                        actions, _, _ = actor(b_state)
                        dist = compute_dist(actions)

                        _, Vs, hs = V(b_state)
                        sampled_action_decoded, log_prob = decode_action(dist, b_action)
                        _, Qs, _ = Q(torch.concat([hs, sampled_action_decoded],-1))

                        _assert_finite(actions, Vs, Qs)

                        probs = log_prob.exp()

                        # PPO ratio against the target-policy probabilities; the truncated importance
                        # weight pi/mu (capped at c) accounts for the behaviour policy
                        r = (probs + 1e-5 )/(b_prob + 1e-5 )
                        adv = (b_prob/b_prob_seg).clip(0,c)*(b_Q_ret - b_V_target)# + b_corr
                        L = torch.minimum(adv*r, adv*r.clip(1-eps,1+eps))

                        # approximation of the entropy
                        if h_samples>0:
                            s = dist.sample((h_samples,))
                            _, s_log_prob = decode_action(dist, s)
                            H = -s_log_prob.exp().mean(0)
                        else:
                            H = -probs

                        # optional Q-based surrogate, currently disabled in actor_error
                        expected_Q = (torch.softmax(Vs.detach()+Qs,1)@bin_values[:,None])[:,0]
                        L2 = torch.minimum(expected_Q*r, expected_Q*r.clip(1-eps,1+eps))

                        actor_error = - L - H*experiment['entropy']# - L2

                        y_V = two_hot_encode(b_V_target, simlog_max_range, simlog_res, simlog_half_res, smoothing=smoothing)
                        y_Q = two_hot_encode(b_Q_ret, simlog_max_range, simlog_res, simlog_half_res, smoothing=smoothing)
                        critic_error_V = torch.nn.functional.cross_entropy(Vs, y_V, reduction='none')
                        critic_error_Q = torch.nn.functional.cross_entropy(Vs+Qs, y_Q, reduction='none')
                        critic_error = (critic_error_V + critic_error_Q) / 2

                        _assert_finite(actor_error, critic_error)

                        optim_critic.zero_grad()
                        critic_error.mean().backward(retain_graph=True)
                        optim_critic.step()
                        optim_actor.zero_grad() # also drops any critic-loss gradient that reached the actor
                        actor_error.mean().backward()
                        optim_actor.step()

                    update_models(decay=5e-3)

        states = new_states

        if (i+1) % validate_every == 0:

            # evaluation only: no autograd graphs are needed here
            with torch.inference_mode():
                Values = []
                A = []
                for b in range(n_grid_batches):
                    stb = st[b*batch_size:(b+1)*batch_size]
                    _, v, _ = V(stb.reshape(stb.shape[0],-1).cuda())
                    a, _, _ = actor(stb.reshape(stb.shape[0],-1).cuda())
                    Values.append(v)
                    A.append(a)
                Values = torch.concat(Values,0)
                A = torch.concat(A,0)

                V_t = []
                for b in range(n_grid_batches):
                    stb = st[b*batch_size:(b+1)*batch_size]
                    _, v, _ = target_V(stb.reshape(stb.shape[0],-1).cuda())
                    V_t.append(v)
                V_t = torch.concat(V_t,0)

                if use_symlog:
                    Values = (Values.softmax(1)@bin_values[:,None])[:,0].cpu()
                    V_t = (V_t.softmax(1)@bin_values[:,None])[:,0].cpu()
                else:
                    # this used to assign to `V`, replacing the critic network with a tensor
                    Values = Values.cpu()
                    V_t = V_t.cpu()

                tb_writer.add_image('V', (Values.reshape(1,100,100)/2+0.5), i)
                tb_writer.add_image('V_t', V_t.reshape(1,100,100)/2+0.5, i)

                # arrows on every second grid point, showing the squashed action
                pos_sub = st[:,0].reshape(100,100,-1)[::2,::2]
                A_sub = A.reshape(100,100,-1)[::2,::2].tanh().cpu()
                arrow_x = pos_sub[...,1].flatten().numpy()
                arrow_y = (-pos_sub[...,0]).flatten().numpy()
                arrow_u = A_sub[...,1].flatten().numpy()
                arrow_v = (-A_sub[...,0]).flatten().numpy()

            ax.quiver(arrow_x, arrow_y, arrow_u, arrow_v, color='g', scale=50, headwidth=2)
            ax.axis('off')
            ax.set_aspect('equal')
            fig.subplots_adjust(0,0,1,1,0,0)
            fig.canvas.draw()
            # buffer_rgba() replaces the deprecated tostring_rgb(); drop the alpha channel
            data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
            # clear only the axes: plt.clf() used to delete `ax` itself, so later frames were
            # drawn on a new axes that `ax.axis('off')` no longer reached
            ax.cla()

            tb_writer.add_image('Policy visualization', np.transpose(data,(2,0,1)), i)

    tb_writer.close()

plt.close(fig)
            



100%|██████████| 131072/131072 [55:11<00:00, 39.58it/s] 
100%|██████████| 131072/131072 [1:58:31<00:00, 18.43it/s] 


<Figure size 1000x1000 with 0 Axes>