In [1]:
%load_ext autoreload
%reload_ext autoreload
%autoreload 2

In [2]:
from gym_swimmer import SwimmerEnv
import torch
import numpy as np
from torch import nn
import math
from models import *
from core import generate_default_model_name
Env = SwimmerEnv

In [3]:
# --- Hyperparameters -------------------------------------------------------
BATCH = 64            # mini-batch size handed to GatherReplayBuffer
N_EPOCH = 12000       # number of iterations of the main training loop (one episode each)
n_candidates = 100    # random action candidates sampled per environment step in the training loop
bthreshold=1e-2       # NOTE(review): reassigned to 1000 by the training cell before it is used there
# Checkpoint paths: take the project's default names and swap the "gnn" tag for "nn".
name_dict = generate_default_model_name(Env)
BMODEL_PATH = name_dict['db'].replace('dbgnn', 'dbnn')
LMODEL_PATH = name_dict['dl'].replace('dlgnn', 'dlnn')

# Use the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Barrier network: conditioned on the state only.
bnn = DMLP(state_dim=Env.state_dim, action_dim=Env.action_dim, mode='straight')
bnn.to(device)
bnn.train()

# Lyapunov network: its input is the state concatenated with the goal, hence the larger state_dim.
lnn = DMLP(state_dim=Env.state_dim+Env.goal_dim, action_dim=Env.action_dim, mode='sum')
lnn.to(device)
lnn.train()

# One Adam optimizer per network. The schedulers are created here, but the only
# `.step()` calls on them in this notebook are commented out in the training cell.
boptimizer = torch.optim.Adam(bnn.parameters(), lr=1e-4, weight_decay=1e-8)
bscheduler = torch.optim.lr_scheduler.ExponentialLR(boptimizer, gamma=0.996)

loptimizer = torch.optim.Adam(lnn.parameters(), lr=1e-4, weight_decay=1e-8)
lscheduler = torch.optim.lr_scheduler.ExponentialLR(loptimizer, gamma=0.996)

In [4]:
def sample_action(nn, o, tensor_a, max_iter=30, mode='max', threshold=-1e-2):
    '''
    Refine a batch of candidate actions by gradient steps on a constraint
    violation, then pick one action per batch row.

    Modes:
        'max' (barrier):   violation is relu(-value + threshold)
        other (Lyapunov):  violation is relu(value - threshold)

    Args:
        nn: model exposing `get_vec(o)` and `get_field(vec, action)`, plus
            `eval()` / `train()`. NOTE: this parameter name shadows the
            module-level `torch.nn` import inside this function.
        o: observations; a 2-D tensor gets a singleton dim inserted at axis 1.
        tensor_a: 3-D tensor of candidate actions (asserted). It is modified
            in place: `requires_grad` is switched on and its values are
            optimised and clamped to [-1, 1].
        max_iter: maximum number of refinement steps.
        mode: 'max' selects the barrier convention, anything else the
            Lyapunov convention.
        threshold: constraint threshold used in the violation.

    Returns:
        (tensor_a, value, finalv, finala): the refined candidates, their field
        values, and the selected value / action for each batch row.

    Side effects:
        Puts `nn` into eval mode during the search and unconditionally into
        train mode before returning.
    '''
    # size of a: (num_agents, n_candidates, action_dim)
    
    if len(o.shape)==2:
        o = o.unsqueeze(1)
    assert len(tensor_a.shape)==3
    n_candidate = tensor_a.shape[1]
    
    nn.eval()
    
    # Encode the observation once (detached) and tile it across the candidates.
    vec = nn.get_vec(o).detach()
    vec = vec.repeat((1, n_candidate, 1))    
    
    # The candidate actions themselves are the optimisation variables.
    tensor_a.requires_grad = True
    aoptimizer = torch.optim.Adam([tensor_a], lr=1)

    iter_ = 0
    while iter_ < max_iter:
        value = nn.get_field(vec, tensor_a)
        if mode=='max':
            cvalue = (-value+threshold).relu()
        else:
            cvalue = (value-threshold).relu()
        # Stop as soon as every row has at least one candidate with zero violation.
        if torch.min(cvalue, dim=-1)[0].sum()==0:
            break
        aoptimizer.zero_grad()
        cvalue.sum().backward()
        # Clip each gradient element to +/-1e-2 before the Adam step.
        torch.nn.utils.clip_grad_value_([tensor_a], 1e-2)
        aoptimizer.step()
        # Project the actions back into the [-1, 1] box.
        with torch.no_grad():
            tensor_a[:] = tensor_a.clamp(-1, 1)
        iter_ += 1
        
    # Re-evaluate the refined candidates.
    value = nn.get_field(vec, tensor_a)
    if mode=='max':
        cvalue = (-value+threshold).relu()
    else:
        cvalue = (value-threshold).relu()    
    
    finalv = torch.zeros_like(value[:, 0])
    finala = torch.zeros_like(tensor_a[:, 0, :])
    # A row is "valid" when at least one of its candidates has zero violation.
    valid = torch.min(cvalue, dim=-1)[0]==0
    if mode=='max':
        if (~valid).sum()!=0:
            # No feasible candidate: fall back to the candidate with the largest value.
            finalv[~valid] = torch.max(value[~valid], dim=-1)[0]
            finala[~valid] = tensor_a[~valid, torch.max(value[~valid], dim=-1)[1]]
        if (valid).sum()!=0:
            # Among feasible candidates take the smallest value; violating
            # candidates are masked out with +inf.
            tvalue = value.clone()
            tvalue[cvalue!=0] = float('inf')
            finalv[valid] = torch.min(tvalue[valid], dim=-1)[0]
            finala[valid] = tensor_a[valid, torch.min(tvalue[valid], dim=-1)[1]]
    else:
        if (~valid).sum()!=0:
            # No feasible candidate: fall back to the candidate with the smallest value.
            finalv[~valid] = torch.min(value[~valid], dim=-1)[0]
            finala[~valid] = tensor_a[~valid, torch.min(value[~valid], dim=-1)[1]]
        if (valid).sum()!=0:
            # Among feasible candidates take the largest value; violating
            # candidates are masked out with -inf.
            tvalue = value.clone()
            tvalue[cvalue!=0] = float('-inf')
            finalv[valid] = torch.max(tvalue[valid], dim=-1)[0] 
            finala[valid] = tensor_a[valid, torch.max(tvalue[valid], dim=-1)[1]]
    
    # NOTE(review): this always switches to train mode, whatever mode the model was in on entry.
    nn.train()
    
    return tensor_a, value, finalv, finala

In [52]:
def train_barrier(bnn, optimizer, buf, pbar, lamda=0.1, n_iter=10, n_candidates=1000, sample='cam'):
    """
    Fit the barrier network on the transitions stored in ``buf``.

    Args:
        bnn: barrier model, called as ``bnn(**batch)`` and ``bnn(x=..., action=...)``.
        optimizer: optimizer over ``bnn``'s parameters.
        buf: buffer whose ``get()`` returns ``(loader, next_loader)`` yielding
            aligned batches for steps t and t+1. Its ``concat_goal`` attribute
            is set to ``False`` here (the barrier does not see the goal).
        pbar: progress bar; its description is updated after each batch.
        lamda: unused, kept for interface compatibility.
        n_iter: number of passes over the buffer.
        n_candidates: number of uniformly sampled actions per sample for the
            contrastive term.
        sample: unused, kept for interface compatibility.

    Returns:
        The description string of the last processed batch, or ``""`` when the
        buffer produced no batches.
    """
    bnn.train()
    buf.concat_goal = False

    def compute_loss(bnn, data, next_data):
        """Return (classification loss, derivative loss, contrastive loss) for one batch pair."""
        value = bnn(**data)
        next_o = data['next_x']
        next_value = bnn(**next_data)

        # Classification: value >= 1e-2 on free transitions, <= -1e-2 on dangerous ones.
        # NOTE(review): the mask of bloss1 is prev_free*next_free but it is
        # normalised by next_free.sum() only -- confirm this is intended.
        bloss1 = ((1e-2-value).relu())*data['prev_free']*data['next_free'] / (1e-9 + (data['next_free']).sum())
        bloss2 = ((1e-2+value).relu())*(data['prev_danger']+data['next_danger']) / (1e-9 + (data['prev_danger']+data['next_danger']).sum())
        bloss = bloss1.sum() + bloss2.sum()

        # Derivative condition on demonstrated transitions that stay free:
        # next_value - 0.9*value >= 1e-2.
        deriv = next_value-value+0.1*value
        free_mask = data['prev_free']*data['next_free']*next_data['next_free']
        dloss = ((-deriv+1e-2).relu())*free_mask
        dloss = dloss.sum() / (1e-9 + free_mask.sum())

        # Contrastive term: uniformly sampled actions at the next state are
        # penalised when they satisfy the same derivative condition.
        a = torch.rand(len(next_o), n_candidates, data['action'].shape[-1]).to(device).uniform_(-1, 1)
        next_value_neg = bnn(x=next_o.unsqueeze(1).repeat(1, n_candidates, 1), action=a)
        deriv = next_value_neg-value.unsqueeze(-1)+0.1*value.unsqueeze(-1)
        contrastloss = ((deriv+1e-2).relu()).mean()

        return bloss, dloss, contrastloss

    # Bind up front so an empty buffer returns "" instead of raising UnboundLocalError.
    desc = ""
    for _ in range(n_iter):
        loader, next_loader = buf.get()
        for data, next_data in zip(loader, next_loader):
            optimizer.zero_grad()
            bloss, dloss, closs = compute_loss(bnn, data, next_data)
            loss = bloss + dloss + closs
            loss.backward()
            optimizer.step()
            # Post-step mean barrier value on this batch, for monitoring only.
            with torch.no_grad():
                b_mean = bnn(**data).mean()
            desc = "bloss %.6f, dloss %.6f, closs %.6f, bmean %.6f" % (bloss, dloss, closs, b_mean)
            pbar.set_description(desc)
            optimizer.zero_grad()

    return desc

In [57]:
def train_lyapunov(lnn, optimizer, buf, pbar, lamda=0.1, n_iter=10, n_candidates=1000, sample='cam'):
    """
    Fit the Lyapunov network on the transitions stored in ``buf``.

    Args:
        lnn: Lyapunov model, called as ``lnn(**batch)`` and ``lnn(x=..., action=...)``.
        optimizer: optimizer over ``lnn``'s parameters.
        buf: buffer whose ``get()`` returns ``(loader, next_loader)`` yielding
            aligned batches for steps t and t+1. Its ``concat_goal`` attribute
            is set to ``True`` here.
        pbar: progress bar; its description is updated after each batch.
        lamda: unused, kept for interface compatibility.
        n_iter: number of passes over the buffer.
        n_candidates: number of uniformly sampled actions per sample for the
            contrastive term.
        sample: must be ``'uniform'`` or ``'cam'``; currently both use uniform
            action sampling.

    Returns:
        The description string of the last processed batch, or ``""`` when the
        buffer produced no batches.

    Raises:
        AssertionError: if ``sample`` is neither ``'uniform'`` nor ``'cam'``.
    """
    assert sample=='uniform' or sample=='cam'

    lnn.train()
    buf.concat_goal = True

    def compute_loss(lnn, data, next_data):
        """Return (goal loss, decrease loss, contrastive loss) for one batch pair."""
        next_o = data['next_x']

        # Uniform action candidates in [-1, 1] for the contrastive term.
        a = torch.rand(len(next_o), n_candidates, data['action'].shape[-1]).to(device).uniform_(-1, 1)

        value = lnn(**data)
        next_value = lnn(**next_data)
        next_value_neg = lnn(x=next_o.unsqueeze(1).repeat(1, n_candidates, 1), action=a)

        # Drive the value to zero on transitions flagged as reaching the goal.
        goal_loss = ((value**2)*data['next_goal']).sum() / (1e-9 + data['next_goal'].sum()) + \
                    ((next_value**2)*next_data['next_goal']).sum() / (1e-9 + next_data['next_goal'].sum())

        # Demonstrated transitions should decrease the value by at least 1e-2.
        deriv = next_value-value
        badloss = (deriv+1e-2).relu().mean()

        # Sampled actions are penalised when they decrease the value.
        deriv = next_value_neg-value.unsqueeze(-1)
        contrastloss = ((-deriv+1e-2).relu()).mean()

        return goal_loss, badloss, contrastloss

    # Bind up front so an empty buffer returns "" instead of raising UnboundLocalError.
    desc = ""
    for _ in range(n_iter):
        loader, next_loader = buf.get()
        for data, next_data in zip(loader, next_loader):
            optimizer.zero_grad()
            goal_loss, dloss, contrastloss = compute_loss(lnn, data, next_data)
            loss = goal_loss + dloss + contrastloss
            loss.backward()
            optimizer.step()
            desc = "goal_loss %.6f, dloss %.6f, closs %.6f" % (goal_loss, dloss, contrastloss)
            pbar.set_description(desc)
            optimizer.zero_grad()

    return desc

In [70]:
# create replay buffer
import scipy
from random import shuffle
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from collections import defaultdict


class DotDict(dict):
    """
    A dictionary that supports dot notation as well as item access.

    usage: d = DotDict() or d = DotDict({'val1':'first'})
    set attributes: d.val2 = 'second' or d['val2'] = 'second'
    get attributes: d.val2 or d['val2']

    Nested mappings (anything exposing ``keys``) are converted to ``DotDict``
    recursively at construction time.
    """

    def __init__(self, dct=None):
        super().__init__()
        # `dct` is optional so that the documented `DotDict()` form works.
        for key, value in (dct or {}).items():
            if hasattr(value, 'keys'):
                value = DotDict(value)
            self[key] = value

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails. A missing key must
        # surface as AttributeError (not KeyError) so that hasattr() and
        # getattr(obj, name, default) behave correctly.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    __setattr__ = dict.__setitem__

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None

    def to(self, device):
        """Replace every value with ``value.to(device)`` in place and return ``self``."""
        # Only existing keys are reassigned, so the dict size is stable during iteration.
        for key, value in self.items():
            self[key] = value.to(device)
        return self


class GlobalReplayBuffer:
    """
    Buffer holding the per-step records of a single trajectory.

    Each call to ``store`` appends one record (a ``DotDict`` of float tensors)
    to ``obs_buf``. ``get`` pairs every record with its successor and serves
    the pairs as two aligned ``DataLoader`` objects.
    """

    def __init__(self, size):
        self.obs_buf = []
        self.ptr = 0            # number of records currently stored
        self.max_size = size    # nominal capacity; not enforced by store()

    def store(self, **kwargs):
        """
        Append one timestep of agent-environment interaction to the buffer.

        Every keyword argument is converted to a float tensor.
        """
        obs = DotDict({})
        for key, value in kwargs.items():
            obs[key] = torch.as_tensor(value, dtype=torch.float)
        self.obs_buf.append(obs)
        self.ptr += 1

    def get(self, batch_size, concat_goal=False):
        """
        Build shuffled, aligned loaders over consecutive record pairs.

        Args:
            batch_size: batch size of both loaders.
            concat_goal: when true, ``goal`` is concatenated onto ``x`` in
                every collated batch.

        Returns:
            ``(loader, next_loader)``: batch i of ``loader`` holds records at
            step t and batch i of ``next_loader`` the matching records at t+1.
            Batches are moved to the module-level ``device``.
        """

        def collate_fn(data):
            """Collate a list of records into one dict of batched tensors on ``device``."""
            data = default_collate(data)
            for k, v in data.items():
                data[k] = v.to(device)
            if concat_goal:
                data['x'] = torch.cat((data['x'], data['goal']), dim=-1)
            return data

        # Shuffle the (t, t+1) pairs jointly so both loaders stay aligned.
        pairs = list(zip(self.obs_buf[:-1], self.obs_buf[1:]))
        shuffle(pairs)

        loader = DataLoader([p[0] for p in pairs], shuffle=False, batch_size=batch_size, collate_fn=collate_fn)
        next_loader = DataLoader([p[1] for p in pairs], shuffle=False, batch_size=batch_size, collate_fn=collate_fn)

        return loader, next_loader

    def relabel_l(self):
        """
        Hindsight-relabel the goal of a trajectory that did not reach its goal.

        A later record is chosen at random; the leading entries of its ``x``
        (as many as ``goal`` has) become the new goal for every record but the
        last. ``next_goal`` is recomputed as "the next state is within 0.1 of
        the new goal", and the buffer is truncated after the last transition
        that reaches it. Does nothing when the trajectory already ended at its
        goal or holds fewer than two records.
        """
        # Use this buffer's own records (previously read from the global `lbuf`).
        if len(self.obs_buf) < 2 or self.obs_buf[-1]['next_goal'] == 1:
            return

        # choose a future state
        chosen_idx = np.random.randint(1, len(self.obs_buf))
        chosen = self.obs_buf[chosen_idx]
        new_goal = chosen['x'].data.cpu().numpy()[:len(chosen['goal'])]
        for idx, (obs, next_obs) in enumerate(zip(self.obs_buf[:-1], self.obs_buf[1:])):
            obs['goal'] = torch.as_tensor(new_goal, dtype=torch.float)
            reached = (next_obs['x'][:len(new_goal)]-obs['goal']).norm() < 0.1
            # Keep the float dtype that store() uses for every field.
            obs['next_goal'] = reached.float()
            if reached:
                chosen_idx = idx + 1

        self.obs_buf = self.obs_buf[:chosen_idx]
        # Keep the record counter in sync with the truncated buffer.
        self.ptr = len(self.obs_buf)

    def relabel_b(self):
        # TODO
        pass
    
    
class GatherReplayBuffer:
    """
    Collection of trajectory buffers served as one shuffled transition dataset.

    Consecutive records inside each stored buffer are paired (t, t+1); pairs
    never span two buffers.
    """

    def __init__(self, batch=64, concat_goal=False):
        self.buffers = []
        self.batch = batch
        self.concat_goal = concat_goal

    def append(self, buffer, reward=1):
        """Tag ``buffer`` with ``reward`` and add it to the collection."""
        buffer.reward = reward
        self.buffers.append(buffer)

    def get(self):
        """
        Return ``(loader, next_loader)`` over all stored transitions.

        Batch i of ``loader`` holds records at step t and batch i of
        ``next_loader`` the matching records at step t+1. ``self.concat_goal``
        is read when a batch is collated, not when this method is called.
        """

        def to_device_batch(records):
            """Collate records into one dict of tensors on ``device``, optionally appending the goal."""
            batch = default_collate(records)
            for name in batch:
                batch[name] = batch[name].to(device)
            if self.concat_goal:
                goal = batch['goal']
                batch['x'] = torch.cat((batch['x'], goal), dim=-1)
                batch['next_x'] = torch.cat((batch['next_x'], goal), dim=-1)
            return batch

        # Build the (t, t+1) pairs buffer by buffer, then shuffle them jointly
        # so the two loaders stay aligned.
        transitions = [
            (cur, nxt)
            for traj in self.buffers
            for cur, nxt in zip(traj.obs_buf[:-1], traj.obs_buf[1:])
        ]
        shuffle(transitions)

        current_records = [cur for cur, _ in transitions]
        following_records = [nxt for _, nxt in transitions]

        loader = DataLoader(current_records, shuffle=False, batch_size=self.batch, collate_fn=to_device_batch)
        next_loader = DataLoader(following_records, shuffle=False, batch_size=self.batch, collate_fn=to_device_batch)

        return loader, next_loader

In [38]:
# Scratch check of a straight-through-style construction:
# `b.relu() - b.relu().detach()` is numerically zero, so the forward value is
# just `-a`, while the gradient still flows through `b.relu()`.
a = torch.FloatTensor([1, 1, 0, 1])
b = torch.FloatTensor([0.1, 1, -0.2, 1])
b.requires_grad = True
(-(a + b.relu() - b.relu().detach())).sum().backward()
# Expected: -1 where b > 0 and 0 where b < 0, i.e. the negated ReLU derivative.
b.grad

tensor([-1., -1.,  0., -1.])

# Warm Up

In [71]:
# Warm-up: roll out a pre-trained PPO policy, keep only trajectories that reach
# the goal without any collision, and pre-train both networks on them.
from gym_swimmer import SwimmerEnv
from stable_baselines3 import PPO
from tqdm import tqdm 

env = SwimmerEnv()
model = PPO.load("swimmer/best_model.zip")

allbuf = GatherReplayBuffer(batch=BATCH)
n_collision = 0   # trajectories with at least one dangerous step
n_reach = 0       # accumulated `next_goal` flag of each trajectory's final step
for _ in tqdm(range(1000)):
    nowbuf = GlobalReplayBuffer(1024)
    obs = env.reset()
    is_collide = False
    returns = 0
    while True:
        # The observation is re-read from the env each step (this overwrites the value from reset()).
        obs = env._get_obs()
        # Keep actions strictly inside the [-1, 1] box.
        ac = model.predict(obs)[0].clip(-0.99, 0.99)
        next_obs, rw, done, info = env.step(ac)
        returns += rw
        if info['next_danger']:
            is_collide = True
        # The `info` dict of the step is what gets stored as the transition record.
        nowbuf.store(**info)
        if done:
            break
    n_reach += info['next_goal']
    n_collision += is_collide
    # Only successful, collision-free trajectories are used as demonstrations.
    if info['next_goal'] and (not is_collide):
        # print(returns)
        allbuf.append(nowbuf, reward=1)  # returns

print(n_reach, n_collision)

# Pre-train both networks for 100 passes over the demonstrations.
with tqdm() as pbar:
    descb = train_barrier(bnn, boptimizer, allbuf, pbar=pbar, n_iter=100)
    descl = train_lyapunov(lnn, loptimizer, allbuf, pbar=pbar, n_iter=100)

# The Lyapunov checkpoint is written with a '_cam' suffix; the barrier one uses BMODEL_PATH as is.
torch.save(bnn.state_dict(), BMODEL_PATH)
torch.save(lnn.state_dict(), LMODEL_PATH.replace('.pt', '_cam.pt'))

100%|██████████| 1000/1000 [00:40<00:00, 24.47it/s]


930 68


goal_loss 0.000000, dloss 0.008990, closs 0.000047: : 0it [18:00, ?it/s]            


In [24]:
%debug

> [0;32m<ipython-input-17-19d52174369c>[0m(127)[0;36mget[0;34m()[0m
[0;32m    125 [0;31m                [0mdata[0m[0;34m[[0m[0;34m'reward'[0m[0;34m][0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mas_tensor[0m[0;34m([0m[0mreward_norm[0m[0;34m,[0m [0mdtype[0m[0;34m=[0m[0mtorch[0m[0;34m.[0m[0mfloat[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    126 [0;31m[0;34m[0m[0m
[0m[0;32m--> 127 [0;31m        [0;32massert[0m [0;32mFalse[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    128 [0;31m[0;34m[0m[0m
[0m[0;32m    129 [0;31m        [0mprev_o[0m [0;34m=[0m [0;34m[[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  min_reward


2.6069009180234284


ipdb>  max_reward


3.103223548748725


ipdb>  exit()


In [63]:
# Re-create both networks and their optimizers from scratch (same settings as
# the setup cell) and retrain them on the existing `allbuf` demonstrations.
bnn = DMLP(state_dim=Env.state_dim, action_dim=Env.action_dim, mode='straight')
bnn.to(device)
bnn.train()

lnn = DMLP(state_dim=Env.state_dim+Env.goal_dim, action_dim=Env.action_dim, mode='sum')
lnn.to(device)
lnn.train()

boptimizer = torch.optim.Adam(bnn.parameters(), lr=1e-4, weight_decay=1e-8)
bscheduler = torch.optim.lr_scheduler.ExponentialLR(boptimizer, gamma=0.996)

loptimizer = torch.optim.Adam(lnn.parameters(), lr=1e-4, weight_decay=1e-8)
lscheduler = torch.optim.lr_scheduler.ExponentialLR(loptimizer, gamma=0.996)

# NOTE(review): the barrier gets 10 passes here versus 100 in the warm-up cell.
# `sample='uniform'` is passed to train_lyapunov, which as defined in this
# notebook only validates the value.
with tqdm() as pbar:
    descb = train_barrier(bnn, boptimizer, allbuf, pbar=pbar, n_iter=10)
    descl = train_lyapunov(lnn, loptimizer, allbuf, pbar=pbar, n_iter=100, sample='uniform')

# This overwrites the barrier checkpoint at BMODEL_PATH; the Lyapunov model goes to a '_uniform' file.
torch.save(bnn.state_dict(), BMODEL_PATH)
torch.save(lnn.state_dict(), LMODEL_PATH.replace('.pt', '_uniform.pt'))

goal_loss 0.000227, dloss 0.001915, closs 0.000016: : 0it [03:34, ?it/s]


In [50]:
%debug

> [0;32m/home/rainorangelemon/anaconda3/envs/gnn/lib/python3.8/site-packages/torch/nn/functional.py[0m(1753)[0;36mlinear[0;34m()[0m
[0;32m   1751 [0;31m    [0;32mif[0m [0mhas_torch_function_variadic[0m[0;34m([0m[0minput[0m[0;34m,[0m [0mweight[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1752 [0;31m        [0;32mreturn[0m [0mhandle_torch_function[0m[0;34m([0m[0mlinear[0m[0;34m,[0m [0;34m([0m[0minput[0m[0;34m,[0m [0mweight[0m[0;34m)[0m[0;34m,[0m [0minput[0m[0;34m,[0m [0mweight[0m[0;34m,[0m [0mbias[0m[0;34m=[0m[0mbias[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1753 [0;31m    [0;32mreturn[0m [0mtorch[0m[0;34m.[0m[0m_C[0m[0;34m.[0m[0m_nn[0m[0;34m.[0m[0mlinear[0m[0;34m([0m[0minput[0m[0;34m,[0m [0mweight[0m[0;34m,[0m [0mbias[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1754 [0;31m[0;34m[0m[0m
[0m[0;32m   1755 [0;31m[0;34m[0m[0m
[0m


ipdb>  exit()


# Train Function

In [12]:
def iter_action(bnn, lnn, o_b, o_l, a, bthreshold=-1e-2, lthreshold=-1e-2, max_iter=30):
    """
    Refine candidate actions so they satisfy both the barrier and the Lyapunov constraint.

    The per-candidate violation is
    ``relu(-bvalue + bthreshold) + relu(lvalue - lthreshold)``. Candidates are
    updated with clipped-gradient Adam steps and clamped to [-1, 1] until at
    least one candidate has zero violation or ``max_iter`` steps were taken.

    Args:
        bnn, lnn: barrier / Lyapunov models exposing ``get_vec`` and
            ``get_field``; both are switched to eval mode and left there.
        o_b: dict of tensors passed as keyword arguments to ``bnn.get_vec``.
        o_l: dict of tensors; its ``'x'`` and ``'goal'`` entries are
            concatenated to form the input of ``lnn.get_vec``.
        a: candidate actions; flattened to ``(-1, a.shape[-1])``.
        bthreshold, lthreshold: constraint thresholds.
        max_iter: maximum number of refinement steps.

    Returns:
        Numpy arrays ``(actions, bvalue, lvalue, cvalue)`` for the refined candidates.
    """
    # size of a: (num_agents, n_candidates, action_dim)
    a = a.reshape((-1, a.shape[-1]))
    n_candidate = a.shape[0]

    bnn.eval()
    lnn.eval()

    # Encode the observation once per network and tile it over the candidates.
    barrier_in = {name: tensor.to(device) for name, tensor in o_b.items()}
    vecb = bnn.get_vec(**barrier_in).detach().reshape(1, -1).repeat((n_candidate, 1))

    lyapunov_in = {name: tensor.to(device) for name, tensor in o_l.items()}
    state_and_goal = torch.cat((lyapunov_in['x'], lyapunov_in['goal']), dim=-1)
    vecl = lnn.get_vec(x=state_and_goal).detach().reshape(1, -1).repeat((n_candidate, 1))

    # The candidate actions are the variables being optimised.
    tensor_a = torch.FloatTensor(a).to(device)
    tensor_a.requires_grad = True
    aoptimizer = torch.optim.Adam([tensor_a], lr=1)

    def evaluate():
        """Field values of both networks and the combined violation for the current candidates."""
        barrier_value = bnn.get_field(vecb, tensor_a)
        lyapunov_value = lnn.get_field(vecl, tensor_a)
        violation = (-barrier_value+bthreshold).relu()+(lyapunov_value-lthreshold).relu()
        return barrier_value, lyapunov_value, violation

    steps_done = 0
    while steps_done < max_iter:
        _, _, cvalue = evaluate()
        # One feasible candidate is enough to stop refining.
        if torch.min(cvalue)==0:
            break
        aoptimizer.zero_grad()
        cvalue.sum().backward()
        torch.nn.utils.clip_grad_value_([tensor_a], 1e-2)
        aoptimizer.step()
        with torch.no_grad():
            tensor_a[:] = tensor_a.clamp(-1, 1)
        steps_done += 1

    bvalue, lvalue, cvalue = evaluate()
    results = (tensor_a, bvalue, lvalue, cvalue)
    return tuple(t.data.cpu().numpy() for t in results)

def choose_action(cvalue):
    """
    Pick a candidate index from per-candidate constraint violations.

    A uniformly random index among the zero-violation candidates is returned
    when there is one; otherwise the index of the smallest violation.
    """
    feasible = cvalue == 0
    if not np.any(feasible):
        return np.argmin(cvalue)
    feasible_idx = np.arange(len(cvalue))[feasible]
    return np.random.choice(feasible_idx, 1)[0]

In [72]:
from tqdm import tqdm
import gc
from copy import deepcopy

# def is_counter_d(o, next_o, free, next_free, barrier, v_cur, v_next):
#     counter_mse = np.abs(v_next - v_cur - barrier) > 1e-2
#     return counter_mse

# def is_counter_b(o, next_o, free, danger, barrier, v_cur, v_next):
#     not_free = next_free.astype(float)<free.astype(float)
#     counter_free = np.logical_and(free, v_cur > -1e-1)
#     counter_obs = np.logical_and(danger, v_cur < 1e-1)
#     counter_barrier = np.logical_and(v_next - v_cur > -0.1 * v_cur, free)
#     return np.logical_or(np.logical_or(counter_free, counter_obs), counter_barrier)
    

# Main training cell: roll out the current barrier/Lyapunov controller, add
# successful collision-free episodes to `allbuf`, and periodically retrain.
# Relies on `allbuf`, `bnn`, `lnn` and the optimizers created by earlier cells.
max_episode_length     = Env.max_episode_steps
EXPERIENCE_BUFFER_SIZE = Env.max_episode_steps

# Truncate (or create) the two log files.
LOG_FILE_L = 'cam_'+Env.__name__+'_l.txt'
LOG_FILE_B = 'cam_'+Env.__name__+'_b.txt'
open(LOG_FILE_L, 'w+').close()
open(LOG_FILE_B, 'w+').close()
# NOTE(review): bbuf and lbuf are created but not used again in this cell.
bbuf = GlobalReplayBuffer(EXPERIENCE_BUFFER_SIZE)
lbuf = GlobalReplayBuffer(EXPERIENCE_BUFFER_SIZE)
env = Env()
# Per-episode state; the same reset statement is repeated when an episode ends.
env.reset(); lthreshold=-1000.; bthreshold=1000; nowbuf = GlobalReplayBuffer(1024); is_collide=False; returns=0
o = env._get_obs()

from gym_swimmer import SwimmerEnv
from stable_baselines3 import PPO
from tqdm import tqdm 

pbar = tqdm(range(N_EPOCH))
for epoch_i in pbar:
    
    total_trans = 0
    # NOTE(review): unsafe_rate is never incremented below, so the logged rate is always 0.
    unsafe_rate = 0
#     buf.max_size += EXPERIENCE_BUFFER_SIZE
    # Main loop: collect experience in env and update/log each epoch
    while True:

        o = env._get_obs()
        # Sample random candidates, refine them against both networks, then pick one.
        a_all = np.random.uniform(-1., 1., size=(n_candidates, env.action_dim))
        o_l = o_b = {'x': torch.FloatTensor(o), 'goal': torch.FloatTensor(env.goal)}
        # The number of refinement steps grows by one every 100 epochs, capped at 30.
        a_refine, bvalue, lvalue, cvalue = iter_action(bnn, lnn, o_b, o_l, a_all, max_iter=min(epoch_i//100, 30), lthreshold=lthreshold, bthreshold=bthreshold)
        idx = choose_action(cvalue)
        a, bvalue, lvalue, cvalue = a_refine[idx, :], bvalue[idx], lvalue[idx], cvalue[idx]
        # lthreshold = lvalue-1e-2
        # bthreshold = 0.9 * bvalue + 1e-2

        next_o, r, d, info = env.step(a)
        returns += r
        if info['next_danger']:
            is_collide = True
        
        nowbuf.store(**info)
        
        total_trans += 1
        free = np.array(r)

        if d:
            # Keep only episodes that reached the goal without any collision.
            if info['next_goal'] and not is_collide:
                allbuf.append(nowbuf, reward=returns)
            env.reset(); lthreshold=-1000.; bthreshold=1000; nowbuf = GlobalReplayBuffer(1024); is_collide=False; returns=0
            break

    unsafe_rate = unsafe_rate / total_trans
    
    # Retrain when the epoch index ends in 999.
    if (epoch_i % 1000) == 999:
        # MAYBE NEED TO RELABEL
        descb = train_barrier(bnn, boptimizer, allbuf, pbar=pbar, n_iter=100)
        descl = train_lyapunov(lnn, loptimizer, allbuf, pbar=pbar, n_iter=100)

        # if (epoch_i % 10 == 0) and (epoch_i != 0) and (epoch_i < 6000):
        #     bscheduler.step()
        #     lscheduler.step()

        with open(LOG_FILE_L, 'a+') as f:
            f.write(descl+'\t'+str(pbar.last_print_n)+'\n')
        with open(LOG_FILE_B, 'a+') as f:
            f.write(descb+'\t'+str(pbar.last_print_n)+'\t'+'unsafe rate: '+str(unsafe_rate)+'\n')     

        torch.save(bnn.state_dict(), BMODEL_PATH)
        torch.save(lnn.state_dict(), LMODEL_PATH)
        
        # NOTE(review): this break leaves the epoch loop after the first
        # training round, so the remaining epochs of N_EPOCH never run.
        break

goal_loss 0.000033, dloss 0.004378, closs 0.000255:   8%|▊         | 999/12000 [20:08<3:41:52,  1.21s/it]            


KeyboardInterrupt: 

In [97]:
%debug

> [0;32m<ipython-input-60-d00378840270>[0m(14)[0;36mcompute_loss[0;34m()[0m
[0;32m     12 [0;31m        [0;32mif[0m [0msample[0m[0;34m==[0m[0;34m'cam'[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     13 [0;31m            [0ma[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mrand[0m[0;34m([0m[0mlen[0m[0;34m([0m[0mnext_o[0m[0;34m)[0m[0;34m,[0m [0mn_candidates[0m[0;34m,[0m [0mdata[0m[0;34m[[0m[0;34m'action'[0m[0;34m][0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;34m-[0m[0;36m1[0m[0;34m][0m[0;34m)[0m[0;34m.[0m[0mto[0m[0;34m([0m[0mdevice[0m[0;34m)[0m[0;34m.[0m[0muniform_[0m[0;34m([0m[0;34m-[0m[0;36m1[0m[0;34m,[0m [0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 14 [0;31m            [0m_[0m[0;34m,[0m [0m_[0m[0;34m,[0m [0m_[0m[0;34m,[0m [0mfinala[0m [0;34m=[0m [0msample_action[0m[0;34m([0m[0mlnn[0m[0;34m,[0m [0mnext_o[0m[0;34m,[0m [0ma[0m[0;34m,[0m [0mmax_iter[0m[0;34

ipdb>  value.shape


torch.Size([])


ipdb>  data['x']


tensor([[-0.4125, -0.3577, -1.3060,  1.4942,  1.4763, -1.4690, -0.5701, -1.7017,
          0.8365,  2.3944,  4.0000,  0.0000]], device='cuda:0')


ipdb>  next_o


tensor([[ 0.4271, -0.0196, -0.0940, -0.4063,  0.9484,  0.6745,  1.5356,  2.6297,
         -2.9808, -2.4587,  4.0000,  0.0000]], device='cuda:0')


ipdb>  value


tensor(0.1808, device='cuda:0')


ipdb>  (value-1e-2).unsqueeze(1)


*** IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)


ipdb>  (value.reshape(len(next_o), 1)-1e-2).unsqueeze(1)


tensor([[[0.1708]]], device='cuda:0')


ipdb>  value.reshape(len(next_o))


tensor([0.1808], device='cuda:0')


ipdb>  exit()


# Inference

In [79]:
def eval_performance(n_traj, gif=None, n_candidates=1000, max_iter=2):
    """
    Roll out the barrier/Lyapunov controller and count successes and collisions.

    Uses the module-level ``env``, ``bnn`` and ``lnn``.

    Args:
        n_traj: number of trajectories to run.
        gif: optional output path; when given, the first trajectory is
            rendered and saved there as an animated GIF.
        n_candidates: random action candidates sampled per step.
        max_iter: refinement steps passed to ``iter_action``.

    Returns:
        ``(n_traj, n_reach, n_collision)`` where ``n_reach`` accumulates the
        final ``next_goal`` flag of each trajectory and ``n_collision`` counts
        trajectories with at least one dangerous step.
    """
    n_collision = 0
    n_reach = 0
    # Bound up front so the save step is well defined even when n_traj == 0.
    imgs = []
    for traj_i in tqdm(range(n_traj)):
        is_collide = False
        env.reset()
        lthreshold = -1000.
        bthreshold = 1000.

        # Only the first trajectory is recorded.
        record = (traj_i == 0) and (gif is not None)
        if record:
            imgs = [env.sim.render(600, 300)]

        while True:
            o = env._get_obs()
            a_all = np.random.uniform(-1., 1., size=(n_candidates, env.action_dim))

            o_l = o_b = {'x': torch.FloatTensor(o), 'goal': torch.FloatTensor(env.goal)}
            a_refine, bvalue, lvalue, cvalue = iter_action(bnn, lnn, o_b, o_l, a_all, max_iter=max_iter, lthreshold=lthreshold, bthreshold=bthreshold)
            idx = choose_action(cvalue)
            ac = a_refine[idx, :]
            # lthreshold = lvalue-1e-2
            # bthreshold = max(0.9*bvalue+1e-2, 1e-2)

            _, _, done, info = env.step(ac)

            if record:
                imgs.append(env.sim.render(600, 300))

            if info['next_danger']:
                is_collide = True
            if done:
                break
        n_reach += info['next_goal']
        n_collision += is_collide
    print('total trajs:'+str(n_traj)+', goal reached: '+str(n_reach)+', collision: '+str(n_collision))

    # Nothing to save when no frames were recorded (n_traj == 0).
    if gif is not None and imgs:
        from PIL import Image
        # Frames are flipped vertically before being written.
        ims = [Image.fromarray(np.flip(a_frame, axis=0)) for a_frame in imgs]
        ims[0].save(gif, save_all=True, append_images=ims[1:], duration=100)

    return n_traj, n_reach, n_collision

In [80]:
# Evaluate 1000 trajectories; the first one is rendered to 'lya_2phase.gif'.
eval_performance(1000, gif='lya_2phase.gif')

100%|██████████| 1000/1000 [02:25<00:00,  6.87it/s]


total trajs:1000, goal reached: 926, collision: 440


(1000, 926, 440)

In [77]:
# Run single trajectories until one fails to reach the goal. The GIF is
# overwritten on every run, so the file left on disk shows the failing run.
# NOTE(review): this loops forever if every trajectory reaches the goal.
while True:
    _, n_reach, n_collision = eval_performance(1, gif='lya_2phase.gif')
    if n_reach==0:
        break

100%|██████████| 1/1 [00:00<00:00,  9.34it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  6.00it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.62it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 13.39it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.68it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  8.20it/s]


total trajs:1, goal reached: 1, collision: 1


100%|██████████| 1/1 [00:00<00:00, 12.60it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 11.28it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  5.58it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  2.36it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.31it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.08it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.67it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.72it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  6.40it/s]


total trajs:1, goal reached: 1, collision: 1


100%|██████████| 1/1 [00:00<00:00, 11.35it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 10.82it/s]


total trajs:1, goal reached: 1, collision: 1


100%|██████████| 1/1 [00:00<00:00,  3.59it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  3.50it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 11.41it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  7.18it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 11.35it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  5.69it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.09it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 11.39it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.77it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  3.98it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  9.42it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 11.92it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.70it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  7.25it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.76it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.63it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.15it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  5.29it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  5.44it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.00it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  8.92it/s]


total trajs:1, goal reached: 1, collision: 1


100%|██████████| 1/1 [00:00<00:00, 12.58it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 11.19it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  6.24it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 11.32it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.70it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.73it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.80it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  3.73it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 10.71it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  6.45it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.06it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  2.57it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  8.96it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.03it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  6.21it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 10.66it/s]


total trajs:1, goal reached: 1, collision: 1


100%|██████████| 1/1 [00:00<00:00,  3.72it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 10.09it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  5.28it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  6.82it/s]


total trajs:1, goal reached: 1, collision: 1


100%|██████████| 1/1 [00:00<00:00,  9.08it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.67it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.58it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 13.42it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  5.83it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 11.96it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  6.81it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.06it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.67it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.09it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 10.75it/s]


total trajs:1, goal reached: 1, collision: 1


100%|██████████| 1/1 [00:00<00:00, 12.30it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  9.11it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  7.17it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.67it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.05it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.53it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.62it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 13.42it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 11.42it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  7.11it/s]


total trajs:1, goal reached: 1, collision: 1


100%|██████████| 1/1 [00:00<00:00,  6.01it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  9.44it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 11.97it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  6.96it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  6.94it/s]


total trajs:1, goal reached: 1, collision: 1


100%|██████████| 1/1 [00:00<00:00, 13.39it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.00it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.57it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  5.27it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.68it/s]


total trajs:1, goal reached: 1, collision: 1


100%|██████████| 1/1 [00:00<00:00,  8.67it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.01it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  9.42it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  7.72it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.02it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  4.66it/s]


total trajs:1, goal reached: 1, collision: 1


100%|██████████| 1/1 [00:00<00:00, 12.02it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  9.37it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.69it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  9.39it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 11.20it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  5.97it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  5.58it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.64it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.07it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  6.34it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 11.15it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 11.98it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.03it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 11.33it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  3.08it/s]


total trajs:1, goal reached: 1, collision: 1


100%|██████████| 1/1 [00:00<00:00, 12.62it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.65it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.68it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  4.73it/s]


total trajs:1, goal reached: 1, collision: 1


100%|██████████| 1/1 [00:00<00:00,  6.51it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  9.40it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  5.69it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  6.96it/s]


total trajs:1, goal reached: 1, collision: 1


100%|██████████| 1/1 [00:00<00:00, 12.68it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.53it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  6.69it/s]


total trajs:1, goal reached: 1, collision: 1


100%|██████████| 1/1 [00:00<00:00, 12.04it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.07it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.78it/s]


total trajs:1, goal reached: 1, collision: 1


100%|██████████| 1/1 [00:00<00:00,  7.69it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 11.93it/s]


total trajs:1, goal reached: 1, collision: 1


100%|██████████| 1/1 [00:00<00:00, 11.41it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 10.64it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.11it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 11.11it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  9.37it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.63it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  9.14it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 11.38it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 11.30it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 12.74it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00, 11.42it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  4.29it/s]


total trajs:1, goal reached: 1, collision: 0


100%|██████████| 1/1 [00:00<00:00,  2.16it/s]


total trajs:1, goal reached: 0, collision: 0


In [18]:
# Evaluate whatever weights are currently loaded in the networks.
# Judging by the logged summary under this cell ("total trajs:1000"), the first
# argument is presumably the number of trajectories and `gif` the path of the
# animation to write -- TODO confirm against the definition of eval_performance.
eval_performance(1000, gif='lya_2phase.gif')

100%|██████████| 1000/1000 [01:31<00:00, 10.89it/s]

total trajs:1000, goal reached: 796, collision: 267





In [None]:
# Swap the '_uniform' checkpoint into `lnn` (presumably the Lyapunov network),
# then repeat the 1000-run evaluation and write its animation to a separate gif.
uniform_ckpt_path = LMODEL_PATH.replace('.pt', '_uniform.pt')
lnn.load_state_dict(
    torch.load(uniform_ckpt_path, map_location=device)
)
eval_performance(1000, gif='lya_uniform.gif')

100%|██████████| 1000/1000 [02:02<00:00,  8.19it/s]


total trajs:1000, goal reached: 580, collision: 609


In [None]:
# Swap the '_cam' checkpoint into `lnn` (presumably the Lyapunov network),
# then repeat the 1000-run evaluation and write its animation to a separate gif.
cam_ckpt_path = LMODEL_PATH.replace('.pt', '_cam.pt')
lnn.load_state_dict(
    torch.load(cam_ckpt_path, map_location=device)
)
eval_performance(1000, gif='lya_cam.gif')

100%|██████████| 1000/1000 [01:43<00:00,  9.70it/s]

total trajs:1000, goal reached: 673, collision: 508





In [59]:
# Single debug rollout: step the environment until it reports `done`, keeping one
# rendered frame per step in `imgs` and printing the applied action together with
# the values returned by `iter_action` for it.
imgs = [env.sim.render(600, 300)]

# NOTE(review): these three counters are neither read nor updated in this cell.
num_tot = 0
num_goaled = 0 
num_collision = 0

# Start from a very low threshold; it is re-derived from `lvalue` after every step.
obs = env.reset(); lthreshold=-1000.
# NOTE(review): `ts` is incremented each step but never read in this cell.
ts = 0
while True:
    ts += 1 
    
    o = env._get_obs()
    # Candidate set: 1000 uniform random actions in [-1, 1), with row 0 replaced
    # by the action from `model.predict` (presumably a reference policy defined
    # in another cell -- TODO confirm).
    a_oracle = model.predict(o)[0]
    a_all = np.random.uniform(-1., 1., size=(1000, env.action_dim))
    a_all[0,:] = a_oracle
    
    # `o_l` and `o_b` are two names for the same dict object.
    o_l = o_b = {'x': torch.FloatTensor(o), 'goal': torch.FloatTensor(env.goal)}
    # max_iter=0 presumably means the candidates are only scored, not refined --
    # TODO confirm against iter_action.
    a_refine, bvalue, lvalue, cvalue = iter_action(bnn, lnn, o_b, o_l, a_all, max_iter=0, lthreshold=lthreshold, bthreshold=bthreshold)
    idx = choose_action(cvalue)

    # NOTE(review): this overwrites the index chosen on the previous line, so row 0
    # is always applied (presumably the `a_oracle` candidate, if iter_action keeps
    # row order). Looks like a debugging override -- confirm it is intentional.
    idx = 0
    
    ac, bvalue, lvalue, cvalue = a_refine[idx, :], bvalue[idx], lvalue[idx], cvalue[idx]
    # The next step's threshold is the current lvalue minus a 1e-2 margin.
    lthreshold = lvalue-1e-2
    
    # Printed order: action, lvalue, bvalue, cvalue.
    print(ac, lvalue, bvalue, cvalue)
    # NOTE(review): the 4th return value is bound to `_`, which another cell reads.
    obs, rw, done, _ = env.step(ac)
    # `ncon` is presumably the simulator's active-contact count; any non-zero
    # value is reported as a collision -- TODO confirm.
    if env.sim.data.ncon!=0:
        print('collision')
    imgs.append(env.sim.render(600, 300))
    if done:
        break

[-1.         0.5831593] 0.42128778 0.014337461 1000.42126
[-1. -1.] 0.3750703 0.028684039 0.0
[1. 1.] 0.35541138 0.037571814 0.0
[1. 1.] 0.32701087 0.045917545 0.0
[-1. -1.] 0.30607203 0.055527944 0.0
[-0.55894417 -1.        ] 0.2863273 0.05855595 0.0
[1. 1.] 0.2701758 0.06429812 0.0
[0.26076806 1.        ] 0.24688685 0.06875621 0.0
[-1. -1.] 0.23051739 0.07218124 0.0
[ 0.31480977 -0.3714502 ] 0.20738778 0.0725701 0.0
[0.37922457 1.        ] 0.19172908 0.07625434 0.0
[-1. -1.] 0.17480066 0.076287225 0.0
[ 0.32679668 -0.488843  ] 0.15865293 0.07306203 0.0
[0.6698414 1.       ] 0.14260043 0.070598036 0.0
[-1. -1.] 0.12231111 0.0795722 0.0
[ 0.55887526 -0.2996163 ] 0.10721043 0.06529434 0.0
[-0.04418565  1.        ] 0.08586976 0.06941733 0.0
[-0.56769294 -1.        ] 0.061580345 0.029957104 0.0
[1.        0.5479828] 0.052531846 0.032769207 0.0009515025
[-1.         0.5602556] 0.040497214 0.019314587 0.0
[-0.70100075 -0.8162678 ] 0.030907454 0.0030950494 0.0073151905
[1.         0.10529279

In [54]:
# NOTE(review): `_` is presumably the 4th value unpacked from `env.step(ac)` in the
# rollout cell (typically the gym info dict). IPython also uses `_` for the last
# cell output, so this lookup depends on execution order -- verify on a fresh kernel.
_['next_goal']

True

In [47]:
from PIL import Image

# Convert every recorded frame in `imgs` to a PIL image, reversing axis 0 first
# (a top-to-bottom flip for row-major image arrays).
ims = [Image.fromarray(np.flipud(frame)) for frame in imgs]

# The first frame acts as the container; all later frames are appended to it.
ims[0].save(
    "cam.gif",
    save_all=True,
    append_images=ims[1:],
    duration=20,
)

In [44]:
# Draw a length-10 uniform sample with element-wise bounds taken from
# Env.state_range: row 0 supplies the lower bounds, row 1 the upper bounds.
state_low = Env.state_range[0, :]
state_high = Env.state_range[1, :]
np.random.uniform(low=state_low, high=state_high, size=(10,))

array([ 5.42882926,  1.37825172, -1.33038784,  0.93033739,  0.60246979,
        0.94529349,  1.50789303, -2.97359193, -1.73358318, -5.0063304 ])