In [None]:
import numpy as np
import torch
from environment import KukaEnv, MazeEnv, SnakeEnv
from environment import Kuka2Env
from torch_geometric.nn import knn_graph
from collections import defaultdict
from time import time
import pickle
from tqdm import tqdm
from torch_sparse import coalesce

INFINITY = float('inf')


def construct_graph(env, points, check_collision=True):
    """Build an undirected k-nearest-neighbour roadmap over ``points``.

    Every point is linked to its 5 nearest neighbours (self-loops included),
    the edge list is mirrored so each edge exists in both directions, and
    duplicate edges are merged with ``coalesce``.

    Args:
        env: environment exposing ``_edge_fp(a, b)``; a truthy result marks the
            segment between the two configurations as usable.
        points: array of sampled configurations, indexed by node id.
        check_collision: when False, ``env._edge_fp`` is never called and every
            edge is treated as free. The default (True) is the original
            behaviour; the flag used to be accepted but ignored.

    Returns:
        edge_cost: dict node -> list of costs aligned with ``neighbors[node]``
            (Euclidean length for free edges, ``INFINITY`` for blocked ones).
        neighbors: dict node -> list of adjacent node ids.
        edge_index: numpy array with one ``(from, to)`` row per directed edge.
        edge_free: list of bools, one per row of ``edge_index``.
    """
    edge_index = knn_graph(torch.FloatTensor(points), k=5, loop=True)
    # mirror the edges so the graph is symmetric, then drop duplicates
    edge_index = torch.cat((edge_index, edge_index.flip(0)), dim=-1)
    edge_index_torch, _ = coalesce(edge_index, None, len(points), len(points))
    edge_index = edge_index_torch.data.cpu().numpy().T
    edge_cost = defaultdict(list)
    edge_free = []
    neighbors = defaultdict(list)
    for edge in edge_index:
        is_free = (not check_collision) or bool(env._edge_fp(points[edge[0]], points[edge[1]]))
        if is_free:
            edge_cost[edge[1]].append(np.linalg.norm(points[edge[1]]-points[edge[0]]))
        else:
            # blocked edges stay in the adjacency lists but can never be relaxed
            edge_cost[edge[1]].append(INFINITY)
        edge_free.append(is_free)
        neighbors[edge[1]].append(edge[0])
    return edge_cost, neighbors, edge_index, edge_free


def min_dist(q, dist):
    """
    Pick the element of ``q`` whose entry in ``dist`` is smallest.

    The first such element in iteration order wins ties; ``None`` is
    returned when ``q`` is empty.
    """
    best = None
    for candidate in q:
        # keep the current best unless the candidate is strictly closer
        if best is not None and not (dist[candidate] < dist[best]):
            continue
        best = candidate
    return best


def dijkstra(nodes, edges, costs, source):
    """Single-source shortest paths over an adjacency-list graph.

    Uses a binary-heap frontier instead of scanning every unvisited node for
    the minimum, so the cost per call drops from quadratic in the number of
    nodes to O((V + E) log V). Distances are the same as before; when several
    equally short paths exist, ``prev`` may name a different (equally valid)
    predecessor.

    Args:
        nodes: iterable of node ids; ``source`` is expected to be one of them.
        edges: mapping node -> list of neighbour ids.
        costs: mapping node -> list of non-negative edge costs, aligned
            index-by-index with ``edges[node]``.
        source: node to start from.

    Returns:
        (dist, prev): ``dist[v]`` is the shortest distance from ``source`` to
        ``v`` and ``prev[v]`` its predecessor on such a path. Unreachable
        nodes keep ``INFINITY`` in both; ``prev[source]`` is ``source``.
    """
    import heapq  # local import: the notebook's import cells do not provide it

    dist = {v: INFINITY for v in nodes}   # unknown distance from source to v
    prev = {v: INFINITY for v in nodes}   # previous node in optimal path

    # distance from source to source
    dist[source] = 0
    prev[source] = source

    visited = set()
    # heap entries are (distance, tie-breaker, node); the counter keeps the
    # heap from ever comparing node objects directly
    counter = 0
    frontier = [(0, counter, source)]

    while frontier:
        d, _, u = heapq.heappop(frontier)
        if u in visited:
            # stale entry: u was already settled with a shorter distance
            continue
        visited.add(u)

        for index, v in enumerate(edges[u]):
            alt = d + costs[u][index]
            if alt < dist[v]:
                # a shorter path to v has been found
                dist[v] = alt
                prev[v] = u
                counter += 1
                heapq.heappush(frontier, (alt, counter, v))

    return dist, prev


# Accumulators: one entry per accepted problem.
data = []
init_states = []
goal_states = []
maps = []
n_sample = [50, 200, 1000]
# Alternative environments that were used with this cell:
# env = MazeEnv(dim=2,  map_file="maze_files/mazes_4000.npz")
# env = KukaEnv(kuka_file="kuka_iiwa/model_3.urdf", map_file="maze_files/kukas_13_3000.pkl")
# env = Kuka2Env()
env = SnakeEnv(map_file='maze_files/snakes_15_2_3000.npz')
# Replace the environment's maps with the inverted 100000-maze archive.
with np.load('maze_files/mazes_100000.npz') as f:
    env.maps = 1-f['arr_0']

time0 = time()

pbar = tqdm(range(100000))
for problem_index in pbar:

    env.init_new_problem(problem_index)
    points = env.uniform_sample(n=500)
    edge_cost, neighbors, edge_index, edge_free = construct_graph(env, points)
    node_ids = list(range(len(points)))

    # Take the first sample that can reach at least one other node and pair
    # it with the farthest node it can reach.
    for source_index in node_ids:
        dist, prev = dijkstra(node_ids, neighbors, edge_cost, source_index)
        dist_values = np.array(list(dist.values()))
        valid_goal = (dist_values != INFINITY) & (dist_values != 0)
        if not valid_goal.any():
            continue

        goal_index = np.flatnonzero(valid_goal)[dist_values[valid_goal].argmax()]
        init_states.append(points[source_index])
        goal_states.append(points[goal_index])
        maps.append(env.maps[problem_index])
        data.append((points, neighbors, edge_cost, edge_index, edge_free))
        break

    # stop once 3000 problems have been collected
    if len(maps) == 3000:
        break

    pbar.set_description(str(len(maps)))

with open('data/pkl/snake_prm_3000.pkl', 'wb') as f:
    pickle.dump(data, f, pickle.DEFAULT_PROTOCOL)

# NOTE(review): this writes to the same .npz path the environment above was
# constructed from -- confirm that overwriting it is intended.
a = dict(maps=maps, init_states=init_states, goal_states=goal_states)
np.savez('maze_files/snakes_15_2_3000.npz', **a)

# To reload the saved problems:
# with np.load('maze_files/snakes_15_2_3000.npz') as f:
#     maps = f['maps']
#     init_states = f['init_states']
#     goal_states = f['goal_states']
#
print('yes')

# train smoother v2

In [None]:
import os
# CUDA_LAUNCH_BLOCKING=1 is the CUDA debugging switch for synchronous kernel
# launches. Both lines set the same variable; the %set_env magic repeats the
# os.environ assignment.
os.environ['CUDA_LAUNCH_BLOCKING']='1'
%set_env CUDA_LAUNCH_BLOCKING=1

In [None]:
from train_smoother import *

from train_explorer import train_explorer
from importlib import reload
import train_smoother
reload(train_smoother)
from train_smoother import train_smoother
import torch
from environment import MazeEnv, KukaEnv, SnakeEnv, UR5Env, Kuka2Env
from str2name import str2name
from copy import deepcopy

# Problem family to work on; str2name turns it into an environment plus the
# explorer/smoother models and the paths of their weight files.
str_ = 'kuka13'
# Used by later cells as the number of problem indices to visit
# (np.random.permutation(epoch), range(epoch)), not as a count of training epochs.
epoch = 2000
env, model_explore, model_explore_path, model_smooth, model_smooth_path = str2name(str=str_)
# The explorer starts from saved weights, loaded on CPU and then moved to `device`.
# NOTE(review): `device` and `SummaryWriter` are not imported in this cell --
# presumably they come from `from train_smoother import *`; confirm.
model_explore.load_state_dict(torch.load(model_explore_path, map_location=torch.device("cpu")))
model_explore.to(device)
# The smoother is saved under a "v2" file name; loading existing weights is disabled.
model_smooth_path = model_smooth_path.replace('.pt', 'v2.pt')
model = model_smooth
model_path = model_smooth_path
# model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
model.to(device)
writer = SummaryWriter()
INFINITY = float('inf')

In [None]:
from smoother import interpolate_path
# env = KukaEnv(kuka_file="kuka_iiwa/model_3.urdf", map_file="maze_files/kukas_13_3000.pkl")
set_random_seed(1234)

# Each replay entry: (problem index, explored path, smoothed path,
# copy of the obstacles, free samples, collided samples).
replay = []
# (problem index, repr(error)) for every problem whose exploration or
# smoothing raised; such problems are still skipped, but no longer silently.
failures = []
for iter_i in range(3):

    indexes = np.random.permutation(epoch)
    pbar = tqdm(indexes)

    for index in pbar:
        env.init_new_problem(index)
        # after the first pass, re-randomise start and goal of each problem
        if iter_i != 0:
            env.set_random_init_goal()

        try:
            path, free, collided = explore(env, model_explore, model, smooth=False)
            # only paths with more than two nodes are smoothed and stored
            if len(path) > 2:
                path_smooth = joint_smoother_ratio([tuple(node) for node in path], env, iter=5)
                replay.append((index, path, path_smooth, deepcopy(env.obstacles), free, collided))

        except Exception as e:
            # best effort: keep going, but record what went wrong and show the tally
            failures.append((int(index), repr(e)))
            pbar.set_description("collected %d, failed %d" % (len(replay), len(failures)))
            continue

In [None]:
import pickle
# Store only (problem index, explored path, smoothed path) of each replay entry.
# The context manager closes the file, so the pickle is fully flushed to disk.
with open("data/oracle_{0:s}.p".format(str_), "wb") as f:
    pickle.dump([(r[0], r[1], r[2]) for r in replay], f)

In [None]:
import pickle
# Load the saved (problem index, path, smoothed path) triples.
with open("data/oracle_{0:s}.p".format(str_), "rb") as f:
    replay = pickle.load(f)

# For every problem index, sample 500 points with negatives requested and
# keep a copy of the obstacles next to the free / collided samples.
extra_data = []
for index in tqdm(range(epoch)):
    env.init_new_problem(index)
    free, collided = env.sample_n_points(500, need_negative=True)
    sampled = (deepcopy(env.obstacles), free, collided)
    extra_data.append(sampled)

# Turn each entry into a list and append the data of the problem it refers to.
for i, r in enumerate(replay):
    replay[i] = [*r, *extra_data[r[0]]]

In [None]:
# Scratch inspection: extra data of the problem referenced by replay[i].
# Relies on `i` left over from the loop in the previous cell.
extra_data[replay[i][0]]

In [None]:
# Number of replay entries.
len(replay)

In [None]:
from model_smoother import ModelSmoother
from train_smoother import *
# Fresh smoother sized from the current environment's workspace and configuration dimensions.
model = ModelSmoother(workspace_size=env.dim, config_size=env.config_dim, embed_size=128, obs_size=6, scale=1).to(device)
# The model was already moved to `device` on the line above; assigning to `_`
# only keeps the notebook from echoing the module repr.
_ = model.to(device)

In [None]:
# Second row of env.bound after reshaping to (2, -1) -- presumably the upper
# limits of the configuration space; confirm against env.bound's layout.
np.array(env.bound).reshape(2, -1)[1,:]

In [None]:
# NOTE(review): the path ends with '/' and names no file -- looks like an
# unfinished checkpoint path; confirm which weights file was meant.
model.load_state_dict(torch.load('data/weights/'))

In [None]:
model.train()
# Checkpoints of this run go to a "v3" file (no-op if the name has no 'v2.pt').
model_path = model_path.replace('v2.pt', 'v3.pt')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Lower the learning rate as soon as the mean epoch loss stops improving.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=0)
optimizer.zero_grad()

loss_min = float('inf')
batch_size = 8
# Running step counter for TensorBoard. Without an explicit step every scalar
# is logged without a position on the x axis, so the curve cannot be read.
global_step = 0

for iter_i in range(20):

    # new random order of the replay entries for every pass
    indexes = np.random.permutation(len(replay))
    # iterate over the batch start offsets directly
    pbar = tqdm(range(0, len(replay), batch_size))
    losses = []

    for index in pbar:

        loss = train(env, replay, model, optimizer, batch_idx=indexes[index:(index+batch_size)])

        losses.append(float(loss))

        pbar.set_description("loss: %.5f" % np.mean(losses))
        writer.add_scalar('loss', loss, global_step)
        global_step += 1

    epoch_loss = np.mean(losses)
    scheduler.step(epoch_loss)

    # keep only the weights of the best pass so far
    if epoch_loss < loss_min:
        loss_min = epoch_loss
        torch.save(model.state_dict(), model_path)

In [None]:
# Leftover post-mortem debugger invocation; it only does something right after a cell has raised.
%debug

In [None]:
# Reload the weights stored at model_path, mapping tensors to CPU while loading.
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))

In [None]:
model.eval()
# Build the model input for the first replay entry.
env_id, path_origin, path_smooth, obstacles, free, collided = replay[0]
data = DotDict({
    k: torch.FloatTensor(v).to(device)
    for k, v in obs_data(model.config_size, obstacles, free, collided).items()
})
data.path = torch.FloatTensor(path_origin).to(device)

# Chain graph over the waypoints: node i+1 -> node i, then the reverse
# direction, then self-loops.
n_waypoints = len(path_origin)
chain_edges = torch.stack((torch.arange(1, n_waypoints), torch.arange(0, n_waypoints-1)), dim=0)
chain_edges = torch.cat((chain_edges, chain_edges.flip(0)), dim=-1)
chain_edges, _ = add_self_loops(chain_edges, num_nodes=len(data.path))
data.edge_index = chain_edges.to(device)

In [None]:
from eval_gnn import path_cost

# Run the smoother for 10 loops and bring the prediction back as a numpy array.
path_pred = model(**data, loop=10).data.cpu().numpy()
# Pin both endpoints to those of the reference smoothed path.
path_pred[0], path_pred[-1] = path_smooth[0], path_smooth[-1]
# Costs in order: reference smoothed path, model prediction, original path.
cost_smooth = path_cost(np.array(path_smooth))
cost_pred = path_cost(path_pred)
cost_origin = path_cost(np.array(path_origin))
print(cost_smooth, cost_pred, cost_origin)

In [None]:
# Show the predicted path.
path_pred

In [None]:
# Reload the weights stored at model_path, mapping tensors to CPU while loading.
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))

In [None]:
# Run the explorer with smoothing enabled on the current problem and unpack
# everything it returns; c0, value, path and smooth_path are inspected in the cells below.
c0, cost, data, explored, forward, success, t0, value, path, smooth_path = explore(env, model_explore, model, smooth=True, batch=50, t_max=1000)

In [None]:
# Collision checks counted since c0 -- assuming c0 is the counter value that
# explore captured before planning; confirm.
env.collision_check_count-c0

In [None]:
# Same difference as in the previous cell, re-evaluated.
env.collision_check_count-c0

In [None]:
# Snapshot of the collision-check counter; later cells report the count relative to it.
c1 = int(env.collision_check_count)

In [None]:
from algorithm.bit_star import BITStar
# Baseline: plan on the current env with BIT* and report the cost of its best
# path. The meaning of batch_size / T / the time budgets is defined by
# algorithm.bit_star, not visible here.
bit = BITStar(env, batch_size=50, T=1000, sampling=None)
solution = bit.plan(INFINITY, time_budget=300, refine_time_budget=0)
path_cost(bit.get_best_path())

In [None]:
# Collision checks counted since the c1 snapshot.
env.collision_check_count-c1

In [None]:
# Same difference as in the previous cell, re-evaluated.
env.collision_check_count-c1

In [None]:
from eval_gnn import path_cost
# Cost of the explored path, then of its smoothed version (both from the explore call above).
print(path_cost(path))
print(path_cost(smooth_path))

In [None]:
import matplotlib.pyplot as plt
# Normalised 50-bin histogram of `value` (returned by explore above).
x = value.cpu().data.numpy()
n, bins, patches = plt.hist(x, 50, density=True, facecolor='g', alpha=0.75)

plt.grid(True)
plt.show()

In [None]:
# Load a specific smoother checkpoint on CPU, then move the model to `device`.
model.load_state_dict(torch.load('data/weights/smooth_14d_attv3.pt', map_location=torch.device("cpu")))
model = model.to(device)

In [None]:
# Start from an empty results dictionary (discards anything accumulated so far).
result_total = {}

In [None]:
from environment import MazeEnv, UR5Env, SnakeEnv, KukaEnv, Kuka2Env
import torch
from train_next import str2next
from str2name import str2name
from algorithm.tsa import NEXT_plan
from config import set_random_seed
from eval_gnn import explore
import numpy as np
from tqdm import tqdm
from eval_gnn import path_cost

set_random_seed(1234)
env = SnakeEnv(GUI=False)

UCB_type = 'kde'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
cuda = torch.cuda.is_available()
model, model_path = str2next(str(env), env)
model.net.load_state_dict(torch.load(model_path, map_location=device))
# (problem index, solution) for every problem NEXT reports as solved
results = []

for i in tqdm(range(3000)):

    pb = env.init_new_problem(i)
    model.set_problem(env.get_problem())
    # same seed for every problem
    set_random_seed(1234)
    solution = NEXT_plan(env=env, model=model, T=1000, g_explore_eps=0.1,
                         stop_when_success=True, UCB_type=UCB_type)
    # solution[1] is used as the success flag; unsolved problems are dropped
    if not solution[1]:
        continue
    results.append((i, solution))

In [None]:
from eval_gnn import path_cost
# Cost of the first path stored in the second successful NEXT result.
path_cost(results[1][1][0].path()[0])

In [None]:
# Switch the environment to problem 2085 for manual inspection.
env.init_new_problem(2085)

In [None]:
# Show the `path` attribute of the current environment.
env.path

In [None]:
def str2env(str):
    """Return ``(env, indexes)`` for a benchmark name.

    Args:
        str: one of 'maze2easy', 'maze2hard', 'kuka7', 'ur5', 'snake7',
            'kuka13', 'kuka14'. (The parameter shadows the builtin; the name
            is kept for callers that pass it by keyword.)

    Returns:
        env: a freshly constructed environment for that benchmark.
        indexes: numpy array of the problem indices to evaluate on.

    Raises:
        ValueError: if the name is unknown. Previously this surfaced as an
            UnboundLocalError on the return statement.
    """
    if str == 'maze2easy':
        env = MazeEnv(dim=2)
        indexes = np.arange(2000, 3000)

    elif str == 'maze2hard':
        env = MazeEnv(dim=2, map_file='maze_files/mazes_hard.npz')
        indexes = np.arange(1000)

    elif str == 'kuka7':
        env = KukaEnv()
        indexes = np.arange(2000, 3000)

    elif str == 'ur5':
        env = UR5Env()
        indexes = np.arange(2000, 3000)

    elif str == 'snake7':
        env = SnakeEnv(map_file='maze_files/snakes_15_2_3000.npz')
        indexes = np.arange(2000, 3000)

    elif str == 'kuka13':
        env = KukaEnv(kuka_file="kuka_iiwa/model_3.urdf", map_file="maze_files/kukas_13_3000.pkl")
        indexes = np.arange(2000, 3000)

    elif str == 'kuka14':
        env = Kuka2Env()
        indexes = np.arange(2000, 3000)

    else:
        raise ValueError('unknown environment name: %r' % (str,))

    return env, indexes

In [None]:
import pickle
# Load previously saved results; the context manager closes the file handle.
with open("data/new_smoother.p", "rb") as f:
    result_total = pickle.load(f)

In [None]:
import pickle
from importlib import reload
import eval_gnn
reload(eval_gnn)
from eval_gnn import eval_gnn, eval_gnn_pure
from eval_next import eval_next
from eval_bit import eval_bit
from eval_rrt import eval_rrt
import numpy as np
from environment import UR5Env, Kuka2Env, SnakeEnv, MazeEnv, KukaEnv
from str2env import str2env
from str2name import str2name
import warnings
warnings.filterwarnings('ignore')
indexes = np.arange(2000, 3000)
# Keep results accumulated by earlier cells; start from scratch only when the
# name is undefined (fresh kernel) or explicitly None.
try:
    if result_total is None:
        result_total = {}
except NameError:
    result_total = {}

methods = [('GNN', eval_gnn), ('GNN_pure', eval_gnn_pure)]
seeds = [1234, 2341, 3412, 4123]

for name, eval_method in methods:

    for problem in ['maze2easy', 'maze2hard', 'snake7', 'ur5', 'kuka7', 'kuka13', 'kuka14']:

        env, indexes = str2env(problem)
        for seed in seeds:
            # cached runs are skipped, so the cell can be resumed after an interruption
            if (problem, name, str(seed)) in result_total:
                continue
            print(problem, name, seed)
            result = eval_method(str(env), seed, env, indexes, use_tqdm=True)
            result_total[problem, name, str(seed)] = result

        # average the first five fields of every seed's result
        results = [result_total[problem, name, str(seed)] for seed in seeds]
        print(problem, name, 'Avg')
        print('success rate:', np.mean([result[0] for result in results]))
        print('collision check: %.2f' % np.mean([result[1] for result in results]))
        print('running time: %.2f' % np.mean([result[2] for result in results]))
        print('path cost: %.2f' % np.mean([result[3] for result in results]))
        print('total time: %.2f' % np.mean([result[4] for result in results]))
        print('')
        result_total[problem, name, 'Avg'] = tuple([np.mean([result[i] for result in results]) for i in range(5)])
        results = []

        # checkpoint after every problem; `with` closes the file so the pickle is flushed
        with open("data/new_smoother.p", "wb") as f:
            pickle.dump(result_total, f)

In [None]:
# Classical baselines, evaluated with the same seeds and problems as the GNN runs.
# Expects `result_total` to exist already (this cell does not create it).
methods = [('BIT*', eval_bit), ('RRT*', eval_rrt)]
seeds = [1234, 2341, 3412, 4123]

for name, eval_method in methods:

    for problem in ['maze2easy', 'maze2hard', 'snake7', 'ur5', 'kuka7', 'kuka13', 'kuka14']:

        env, indexes = str2env(problem)
        for seed in seeds:
            # cached runs are skipped, so the cell can be resumed after an interruption
            if (problem, name, str(seed)) in result_total:
                continue
            print(problem, name, seed)
            result = eval_method(str(env), seed, env, indexes, batch=100, use_tqdm=True)
            result_total[problem, name, str(seed)] = result

        # average the first five fields of every seed's result
        results = [result_total[problem, name, str(seed)] for seed in seeds]
        print(problem, name, 'Avg')
        print('success rate:', np.mean([result[0] for result in results]))
        print('collision check: %.2f' % np.mean([result[1] for result in results]))
        print('running time: %.2f' % np.mean([result[2] for result in results]))
        print('path cost: %.2f' % np.mean([result[3] for result in results]))
        print('total time: %.2f' % np.mean([result[4] for result in results]))
        print('')
        result_total[problem, name, 'Avg'] = tuple([np.mean([result[i] for result in results]) for i in range(5)])
        results = []

        # checkpoint after every problem; `with` closes the file so the pickle is flushed
        with open("data/bit_rrt_new_env.p", "wb") as f:
            pickle.dump(result_total, f)

In [None]:
import pickle
from importlib import reload
import eval_gnn
reload(eval_gnn)
from eval_gnn import eval_gnn, eval_gnn_pure
from eval_next import eval_next
from eval_bit import eval_bit
from eval_rrt import eval_rrt
import numpy as np
from environment import UR5Env, Kuka2Env, SnakeEnv, MazeEnv, KukaEnv
from str2env import str2env
from str2name import str2name
import warnings
warnings.filterwarnings('ignore')
indexes = np.arange(2000, 3000)
# Keep results accumulated by earlier cells; start from scratch only when the
# name is undefined (fresh kernel) or explicitly None.
try:
    if result_total is None:
        result_total = {}
except NameError:
    result_total = {}

methods = [('GNN', eval_gnn)]

for name, eval_method in methods:

    for problem in ['maze2easy', 'maze2hard', 'snake7', 'ur5', 'kuka7', 'kuka13', 'kuka14']:

        env, indexes = str2env(problem)
        # sweep the batch size; t_max is set to the same value, seed fixed at 1234
        for batch in [50, 100, 200, 300, 500, 1000]:
            if (problem, name, batch) in result_total:
                continue
            print(problem, name, batch)
            result = eval_method(str(env), 1234, env, indexes, use_tqdm=True, batch=batch, t_max=batch)
            result_total[problem, name, batch] = result

        # checkpoint after every problem; `with` closes the file so the pickle is flushed
        with open("data/probing_samples.p", "wb") as f:
            pickle.dump(result_total, f)

In [None]:
import pickle
from importlib import reload
import eval_gnn
reload(eval_gnn)
from eval_gnn import eval_gnn, eval_gnn_pure
from eval_next import eval_next
from eval_bit import eval_bit
from eval_rrt import eval_rrt
import numpy as np
from environment import UR5Env, Kuka2Env, SnakeEnv, MazeEnv, KukaEnv
from str2env import str2env
from str2name import str2name
import warnings
warnings.filterwarnings('ignore')
indexes = np.arange(2000, 3000)
# Keep results accumulated by earlier cells; start from scratch only when the
# name is undefined (fresh kernel) or explicitly None.
try:
    if result_total is None:
        result_total = {}
except NameError:
    result_total = {}

methods = [('Oracle', eval_gnn)]

for name, eval_method in methods:

    for problem in ['maze2easy', 'maze2hard', 'snake7', 'ur5', 'kuka7', 'kuka13', 'kuka14']:

        env, indexes = str2env(problem)
        # one run per problem (seed 1234) with the smoother set to 'oracle'
        if (problem, name) in result_total:
            continue
        print(problem, name)
        result = eval_method(str(env), 1234, env, indexes, use_tqdm=True, smoother='oracle')
        result_total[problem, name] = result

        # checkpoint after every problem; `with` closes the file so the pickle is flushed
        with open("data/oracle_test.p", "wb") as f:
            pickle.dump(result_total, f)

In [3]:
import pickle
from importlib import reload
import eval_gnn
reload(eval_gnn)
from eval_gnn import eval_gnn, eval_gnn_pure
from eval_next import eval_next
from eval_bit import eval_bit
from eval_rrt import eval_rrt
import numpy as np
from environment import UR5Env, Kuka2Env, SnakeEnv, MazeEnv, KukaEnv
from str2env import str2env
from str2name import str2name
import warnings
warnings.filterwarnings('ignore')
indexes = np.arange(2000, 3000)
# Keep results accumulated by earlier cells; start from scratch only when the
# name is undefined (fresh kernel) or explicitly None.
try:
    if result_total is None:
        result_total = {}
except NameError:
    result_total = {}

methods = [('GNN', eval_gnn), ('BIT*', eval_bit)]
seeds = [1234]


for name, eval_method in methods:

    for problem in ['maze2easy', 'maze2hard', 'snake7', 'ur5', 'kuka7', 'kuka13', 'kuka14']:

        env, indexes = str2env(problem)
        for seed in seeds:
            # cached runs are skipped, so the cell can be resumed after an interruption
            if (problem, name, str(seed)) in result_total:
                continue
            print(problem, name, seed)
            result = eval_method(str(env), seed, env, indexes, use_tqdm=True, batch=100)
            result_total[problem, name, str(seed)] = result

        # average the first five fields of every seed's result
        results = [result_total[problem, name, str(seed)] for seed in seeds]
        print(problem, name, 'Avg')
        print('success rate:', np.mean([result[0] for result in results]))
        print('collision check: %.2f' % np.mean([result[1] for result in results]))
        print('running time: %.2f' % np.mean([result[2] for result in results]))
        print('path cost: %.2f' % np.mean([result[3] for result in results]))
        print('total time: %.2f' % np.mean([result[4] for result in results]))
        print('')
        result_total[problem, name, 'Avg'] = tuple([np.mean([result[i] for result in results]) for i in range(5)])
        results = []

#         pickle.dump(result_total, open("data/new_env.p", "wb"))

gnn 0.01s, search 0.04s, explored 6:   0%|          | 0/1000 [00:00<?, ?it/s]

maze2easy GNN 1234


gnn 0.01s, search 0.02s, explored 4: 100%|██████████| 1000/1000 [01:14<00:00, 13.50it/s] 
  0%|          | 0/1000 [00:00<?, ?it/s]

success rate: 1000
collision check: 581.34
collision check explore: 474.64
running time: 0.07
path cost: 1.31
total time: 72.16
total time explore: 59.28

maze2easy GNN Avg
success rate: 1000.0
collision check: 581.34
running time: 0.07
path cost: 1.31
total time: 72.16

maze2hard GNN 1234


gnn 0.01s, search 0.09s, explored 12:  24%|██▍       | 242/1000 [00:41<02:11,  5.77it/s] 


KeyboardInterrupt: 

In [1]:
import pickle
from importlib import reload
import eval_gnn
reload(eval_gnn)
from eval_gnn import eval_gnn, eval_gnn_pure
from eval_next import eval_next
from eval_bit import eval_bit
from eval_rrt import eval_rrt
import numpy as np
from environment import UR5Env, Kuka2Env, SnakeEnv, MazeEnv, KukaEnv
from str2env import str2env
from str2name import str2name
import warnings
warnings.filterwarnings('ignore')
indexes = np.arange(2000, 3000)
# Keep results accumulated by earlier cells; start from scratch only when the
# name is undefined (fresh kernel) or explicitly None.
try:
    if result_total is None:
        result_total = {}
except NameError:
    result_total = {}

methods = [('GNN', eval_gnn), ('GNN_pure', eval_gnn_pure)]
seeds = [1234, 2341, 3412, 4123]


for name, eval_method in methods:

    for problem in ['maze2easy', 'maze2hard', 'snake7', 'ur5', 'kuka7', 'kuka13', 'kuka14']:

        env, indexes = str2env(problem)
        for seed in seeds:
            # cached runs are skipped, so the cell can be resumed after an interruption
            if (problem, name, str(seed)) in result_total:
                continue
            print(problem, name, seed)
            result = eval_method(str(env), seed, env, indexes, use_tqdm=True, batch=100, k=10)
            result_total[problem, name, str(seed)] = result

        # average the first five fields of every seed's result
        results = [result_total[problem, name, str(seed)] for seed in seeds]
        print(problem, name, 'Avg')
        print('success rate:', np.mean([result[0] for result in results]))
        print('collision check: %.2f' % np.mean([result[1] for result in results]))
        print('running time: %.2f' % np.mean([result[2] for result in results]))
        print('path cost: %.2f' % np.mean([result[3] for result in results]))
        print('total time: %.2f' % np.mean([result[4] for result in results]))
        print('')
        result_total[problem, name, 'Avg'] = tuple([np.mean([result[i] for result in results]) for i in range(5)])
        results = []

#         pickle.dump(result_total, open("data/gnn_new_env.p", "wb"))

maze2easy GNN 1234


gnn 0.01s, search 0.06s, explored 9:  11%|█▏        | 114/1000 [00:06<00:48, 18.13it/s] 


KeyboardInterrupt: 

In [None]:
# Re-key every result as (first key component, 'Oracle').
# Note: entries whose keys share the same first component overwrite each other.
another_r = {}
for key, value in result_total.items():
    another_r[key[0], 'Oracle'] = value

In [None]:
# Save the re-keyed oracle results; `with` closes the file so the pickle is flushed.
with open("data/oracle_test.p", "wb") as f:
    pickle.dump(another_r, f)

In [None]:
# Run the maze_prm.py script in a subprocess from the notebook.
!python maze_prm.py

In [None]:
# Save the accumulated results; `with` closes the file so the pickle is flushed.
with open("data/new_smoother.p", "wb") as f:
    pickle.dump(result_total, f)

In [None]:
# Leftover post-mortem debugger invocation; it only does something right after a cell has raised.
%debug

In [None]:
# List the keys recorded so far.
result_total.keys()

In [None]:
# Save the accumulated results; `with` closes the file so the pickle is flushed.
with open("data/kuka14.p", "wb") as f:
    pickle.dump(result_total, f)

In [None]:
# Load the results saved in data/kuka14.p, replacing the in-memory result_total.
with open("data/kuka14.p", "rb") as f:
    result_total = pickle.load(f)

In [None]:
# Leftover post-mortem debugger invocation; it only does something right after a cell has raised.
%debug

In [None]:
from eval_gnn import path_cost
# Per-problem cost difference between the paths stored in the last two fields
# of the kuka14 GNN result (seed 1234) -- presumably the path before and after
# smoothing; confirm against eval_gnn's return value.
save = [path_cost(p1)-path_cost(p2) for p1, p2 in zip(result_total['kuka14', 'GNN', '1234'][-2], result_total['kuka14', 'GNN', '1234'][-1])]

In [None]:
from eval_gnn import path_cost
# Per-entry cost difference between replay fields 1 and 2 (explored path and
# its smoothed version, as stored when the replay was collected).
save2 = [path_cost(r[1])-path_cost(r[2]) for r in replay]

In [None]:
# Mean cost difference over the replay entries.
np.mean(save2)

In [None]:
# Mean cost difference over the evaluated problems.
np.mean(save)

In [None]:
# Boolean mask of problems whose cost difference is exactly zero.
bad_prob = np.array(save)==0

In [None]:
# Zero-difference problems whose path in the second-to-last result field has at least 4 nodes.
azhe = np.logical_and(np.array([len(p1) for p1 in result_total['kuka14', 'GNN', '1234'][-2]])>=4, bad_prob)

In [None]:
# Mean cost gap between the GNN and BIT* paths (last result field, seed 1234), restricted to the `azhe` mask.
np.mean(np.array([path_cost(p1)-path_cost(p2) for p1, p2 in zip(result_total['kuka14', 'GNN', '1234'][-1],  result_total['kuka14', 'BIT*', '1234'][-1])])[azhe])

In [None]:
# Number of problems whose cost difference is exactly zero.
len(np.where(np.array(save)==0)[0])

In [None]:
# Flags, per problem, whether the stored GNN_pure path (seed 4123) is empty.
fail = [len(p2)==0 for p2 in result_total['kuka14', 'GNN_pure', '4123'][-1]]

In [None]:
# Save the accumulated results; `with` closes the file so the pickle is flushed.
with open("data/kuka14.p", "wb") as f:
    pickle.dump(result_total, f)

In [None]:
import pickle
# Save the accumulated results; `with` closes the file so the pickle is flushed.
with open("data/snake_gnn_bit_rrt.p", "wb") as f:
    pickle.dump(result_total, f)

## Train NEXT

In [None]:
# Run the train_next.py script in a subprocess from the notebook.
!python train_next.py

In [None]:
# String form of the current environment; the eval_* calls in this notebook receive it as their first argument.
str(env)

In [None]:
# Load previously saved NEXT results; the context manager closes the file handle.
with open('data/new_next.p', 'rb') as f:
    result_total = pickle.load(f)

In [None]:
from eval_next import eval_next
from str2name import str2name
from str2env import str2env
import numpy as np
import pickle
import warnings
warnings.filterwarnings('ignore')
# Keep results accumulated by earlier cells; start from scratch only when the
# name is undefined (fresh kernel) or explicitly None.
try:
    if result_total is None:
        result_total = {}
except NameError:
    result_total = {}

seeds = [1234, 2341, 3412, 4123]

for str_ in ['maze2easy', 'maze2hard', 'snake7', 'ur5', 'kuka7', 'kuka13', 'kuka14']:

    env, indexes = str2env(str_)
    for seed in seeds:
        # cached runs are skipped, so the cell can be resumed after an interruption
        if (str_, 'NEXT', str(seed)) in result_total:
            continue
        print(str_, 'NEXT', seed)
        result = eval_next(str(env), seed, env, indexes, use_tqdm=True)
        result_total[str_, 'NEXT', str(seed)] = result

    # average the first five fields of every seed's result
    results = [result_total[str_, 'NEXT', str(seed)] for seed in seeds]
    print(str_, 'NEXT', 'Avg')
    print('success rate:', np.mean([result[0] for result in results]))
    print('collision check: %.2f' % np.mean([result[1] for result in results]))
    print('running time: %.2f' % np.mean([result[2] for result in results]))
    print('path cost: %.2f' % np.mean([result[3] for result in results]))
    print('total time: %.2f' % np.mean([result[4] for result in results]))
    print('')
    result_total[str_, 'NEXT', 'Avg'] = tuple([np.mean([result[i] for result in results]) for i in range(5)])
    results = []

    # checkpoint after every environment; `with` closes the file so the pickle is flushed
    with open("data/new_next.p", "wb") as f:
        pickle.dump(result_total, f)

In [None]:
from eval_next import eval_next
from str2name import str2name
from str2env import str2env
import numpy as np
import pickle
import warnings
warnings.filterwarnings('ignore')
# Keep results accumulated by earlier cells; start from scratch only when the
# name is undefined (fresh kernel) or explicitly None.
try:
    if result_total is None:
        result_total = {}
except NameError:
    result_total = {}

seeds = [1234]

# NEXT on the two maze benchmarks only, with t_max=1000. Nothing is written to disk here.
for str_ in ['maze2easy', 'maze2hard']:

    env, indexes = str2env(str_)
    for seed in seeds:
        # cached runs are skipped, so the cell can be resumed after an interruption
        if (str_, 'NEXT', str(seed)) in result_total:
            continue
        print(str_, 'NEXT', seed)
        result = eval_next(str(env), seed, env, indexes, use_tqdm=True, t_max=1000)
        result_total[str_, 'NEXT', str(seed)] = result

    # average the first five fields of every seed's result
    results = [result_total[str_, 'NEXT', str(seed)] for seed in seeds]
    print(str_, 'NEXT', 'Avg')
    print('success rate:', np.mean([result[0] for result in results]))
    print('collision check: %.2f' % np.mean([result[1] for result in results]))
    print('running time: %.2f' % np.mean([result[2] for result in results]))
    print('path cost: %.2f' % np.mean([result[3] for result in results]))
    print('total time: %.2f' % np.mean([result[4] for result in results]))
    print('')
    result_total[str_, 'NEXT', 'Avg'] = tuple([np.mean([result[i] for result in results]) for i in range(5)])
    results = []

In [None]:
from eval_gnn import path_cost
# Cost of every path in the last field of the most recent `result`
# (whatever the previously run evaluation cell left in that variable).
costs = []
for i, p in enumerate(result[-1]):
    c = path_cost(p)
    costs.append(c)

In [None]:
# Position of the most expensive path in `costs`.
np.argmax(costs)

In [None]:
# List the keys recorded so far.
result_total.keys()

In [None]:
# Run the train_no_obstacle.py script in a subprocess from the notebook.
!python train_no_obstacle.py

In [None]:
# Load the saved "pure" results; the context manager closes the file handle.
with open("data/pure.p", "rb") as f:
    e = pickle.load(f)

In [None]:
# Normalise keys so the third component is always a string (e.g. 1234 -> '1234').
# Iterating over a snapshot makes it safe to add and delete entries in the loop.
for k, v in list(e.items()):
    new_key = (k[0], k[1], str(k[2]))
    # When the third component is already a string the new key equals the old
    # one; deleting it would drop the entry that was just written.
    if new_key != k:
        e[new_key] = v
        del e[k]

In [None]:
# Use the loaded dictionary as the working results (same object, not a copy).
result_total = e

In [None]:
# Evaluate the models built with use_obstacle=False ("GNN_pure") on the two
# high-dimensional Kuka benchmarks.
env_names = ['Kuka_13D', 'Kuka_14D']
envs = [
    KukaEnv(kuka_file="kuka_iiwa/model_3.urdf", map_file="maze_files/kukas_13_3000.pkl"),
    Kuka2Env(),
    ]
indexeses = [np.arange(2000, 3000), np.arange(2000, 3000)]
seeds = [1234, 2341, 3412, 4123]
methods = [eval_gnn]
method_names = ['GNN_pure']
result_total = {}


for env_name, env, indexes in zip(env_names, envs, indexeses):
    # One result per seed. The list has to be created once per environment:
    # resetting it inside the seed loop made the "Avg" below cover only the
    # last seed.
    results = []
    for seed in seeds:

        print(env_name, 'GNN_pure', seed)
        # str2name returns its own environment instance together with the models and their weight paths
        env, model_explore, model_explore_path, model_smooth, model_smooth_path = str2name(env.__str__(), get_data=False, use_obstacle=False)
        model_explore.load_state_dict(torch.load(model_explore_path, map_location=torch.device("cpu")))
        model_smooth.load_state_dict(torch.load(model_smooth_path, map_location=torch.device("cpu")))
        result = eval_gnn(env.__str__(), seed, env, indexes, model=model_explore, model_s=model_smooth, use_tqdm=True)
        results.append(result)
        result_total[env_name, 'GNN_pure', str(seed)] = result
        # checkpoint after every seed; `with` closes the file so the pickle is flushed
        with open("data/pure.p", "wb") as f:
            pickle.dump(result_total, f)

    print(env_name, 'GNN_pure', 'Avg')
    print('success rate:', np.mean([result[0] for result in results]))
    print('collision check: %.2f' % np.mean([result[1] for result in results]))
    print('running time: %.2f' % np.mean([result[2] for result in results]))
    print('path cost: %.2f' % np.mean([result[3] for result in results]))
    print('total time: %.2f' % np.mean([result[4] for result in results]))
    print('')
    result_total[env_name, 'GNN_pure', 'Avg'] = tuple([np.mean([result[i] for result in results]) for i in range(5)])

In [None]:
# Merge the loaded results with the fresh ones; on equal keys the entry from result_total wins.
result_total = {**e, **result_total}

In [None]:
# List the keys recorded so far.
result_total.keys()

In [None]:
import torch_geometric
from torch_cluster import knn
# Scratch call of torch_cluster.knn on two 1-D integer ranges with k=1.
knn(torch.arange(10), torch.arange(20), 1)

In [None]:
# Number of fields in one stored result tuple.
len(result_total['Kuka_14D', 'GNN_pure', '1234'])

In [None]:
# Write the merged results back to data/pure.p using pickle's default protocol.
with open('data/pure.p', 'wb') as f:
    pickle.dump(result_total, f, pickle.DEFAULT_PROTOCOL)

In [None]:
# Check whether PyTorch can see a CUDA device.
torch.cuda.is_available()

In [None]:
import plotly.graph_objects as go
import plotly
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt
from collections import defaultdict

# Environment names as matched against the results file, and the labels shown
# on the x axis (same order).
envs = ['Maze_2D_Easy', 'Maze_2D_Normal', 'Maze_2D_Hard', 'Kuka_7D', 'Kuka_13D', 'Kuka_14D']
envs_dimension = ['2D Easy', '2D Normal', '2D Hard', '7D', '13D', '14D']
# Metric names, in the order their lines are read from the results file, and
# the matching plot titles.
metrics = ['success rate', 'collision check', 'running time', 'path cost', 'total time']
titles = ['Success Rate', 'Collision Check', 'Running Time', 'Path Cost', 'Total Time']
# Filled below with (metric, env, method) -> value.
metric_values = {}
# Methods to plot and the bar colour of each (same order).
method_names = ['GNN', 'BIT*', 'NEXT', 'RRT*' ]
colors = ['rgb(236, 86, 86)', 'orange', '#2ca02c', 'rgb(100, 114, 246)']

# NOTE(review): machine-specific absolute path to the orca executable; it has
# to be adapted on any other machine.
plotly.io.orca.config.executable = '/Users/rainorangelemon/anaconda3/envs/pybullet/bin/orca'

# fig = go.Figure(data=[
#     go.Bar(name='GNN', x=envs, y=[1000, 1000, 1000]),
#     go.Bar(name='BIT*', x=envs, y=[1000, 1000, 1000]),
#     go.Bar(name='RRT*', x=envs, y=[958.5, 609.5, 170.25]),
#     go.Bar(name='NEXT', x=envs, y=[999, 992.25, 655.5]),
# ])
# # Change the bar mode
# fig.update_layout(barmode='group', title="Success Rates",
#     xaxis_title="environments",
#     yaxis_title="success rates",)
# fig.show()

# Parse the text report: a line that mentions a known environment and 'Avg'
# is followed by one "<metric>: <value>" line per metric, in `metrics` order.
known_envs = set(envs)
with open("data/results/result__.txt", "r") as f:
    for line in f:
        result = line.strip()
        if 'Avg' not in result or not known_envs.intersection(result.split(' ')):
            continue
        # first two tokens of the header line: environment and method
        key = result.split()[:2]
        for metric in metrics:
            value = float(f.readline().split(': ')[1])
            metric_values[(metric, *key)] = value
print(metric_values)

# Fold 'Maze_2D_Normal' into 'Maze_2D_Easy' by averaging the two, then drop
# the 'Normal' column from the lists used for plotting.
for metric in metrics:
    for method in method_names:
        easy_value = metric_values[metric, 'Maze_2D_Easy', method]
        normal_value = metric_values[metric, 'Maze_2D_Normal', method]
        metric_values[metric, 'Maze_2D_Easy', method] = (easy_value + normal_value) / 2

envs.remove("Maze_2D_Normal")
envs_dimension.remove("2D Normal")

# One grouped bar chart per metric: environments on the x axis, one bar per method.
for metric_id, metric_title in enumerate(zip(metrics, titles)):
    metric, title = metric_title
    # success rate values are divided by 1000 before plotting (presumably the
    # number of test problems, giving a fraction); other metrics are plotted as stored
    is_success_rate = metric == 'success rate'
    data = defaultdict(list)
    for method in method_names:
        for env, env_dimension in zip(envs, envs_dimension):
            stored = metric_values[metric, env, method]
            data[method].append(stored / 1000 if is_success_rate else stored)

    bars = []
    for color, method in zip(colors, method_names):
        bars.append(go.Bar(name=method, x=envs_dimension, y=data[method], text=data[method], marker_color=color))
    fig = go.Figure(bars)
    # fig.update_traces(texttemplate='%{text:.2f}', textposition='outside')

    # fixed margins so every exported figure has the same geometry
    fig.layout.margin.autoexpand = False
    fig.layout.margin.t = 10
    fig.layout.margin.b = 40
    fig.layout.margin.l = 60

    layout_options = dict(
        barmode='group',
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        title=title,
        title_xanchor='center',
        font_color='black',
        title_x=0.5,
        title_y=0.98,
        xaxis_title="Environments",
        yaxis_title=title,
    )
    # metric-specific y ranges
    if "success" in metric:
        layout_options['yaxis_range'] = [0, 1.05]
    if "collision" in metric:
        layout_options['yaxis_range'] = [0, 5300]
    fig.update_layout(**layout_options)

    axis_grey = 'rgb(176,176,176)'
    fig.update_xaxes(showline=True, gridcolor='white', linecolor=axis_grey)
    fig.update_yaxes(showline=True, gridcolor=axis_grey, linecolor=axis_grey, )
    fig.write_image("data/images/%s.pdf" % metric.replace(" ", "_"))
    fig.show()


# print(metric_values.items())
#
# for metric in metrics:
#     for method in method_names:
#         plt.xlabel('environment')
#         plt.title(metric)
#         plt.ylabel(metric)
#         plt.plot(envs_dimension, [v for k, v in metric_values.items() if ((k[2]==method) and (k[0]==metric))], label=method)
#     plt.legend()
#     plt.savefig("data/images/%s.pdf" % metric.replace(" ", "_"))
#     plt.clf()


In [None]:
# Install psutil with conda from inside the notebook (unpinned version; -y skips the confirmation prompt).
!conda install -y psutil

In [None]:
import numpy as np
import torch
from environment import KukaEnv, MazeEnv, SnakeEnv, UR5Env
from environment import Kuka2Env
from torch_geometric.nn import knn_graph
from collections import defaultdict
from time import time
import pickle
from tqdm import tqdm
from torch_sparse import coalesce

INFINITY = float('inf')


def construct_graph(env, points, check_collision=True):
    """
    Build a symmetric 5-nearest-neighbour graph over `points` and label its edges.

    The k-NN edges (self loops included, `loop=True`) are concatenated with their
    flipped copies and passed through `coalesce`, then every resulting edge is
    tested with `env._edge_fp`.

    Parameters:
        env: object providing `_edge_fp(a, b)`; a truthy result marks the edge free.
        points: indexable collection of configurations accepted by torch.FloatTensor.
        check_collision: accepted but never read by this function.

    Returns:
        edge_cost: dict mapping an edge's second endpoint to a list of costs
            (Euclidean length for free edges, INFINITY otherwise).
        neighbors: dict mapping an edge's second endpoint to the list of first
            endpoints, aligned entry by entry with edge_cost.
        edge_index: numpy array with one (first, second) row per edge.
        edge_free: list of booleans, one per row of edge_index.
    """
    n_points = len(points)
    knn_edges = knn_graph(torch.FloatTensor(points), k=5, loop=True)
    both_directions = torch.cat((knn_edges, knn_edges.flip(0)), dim=-1)
    merged_edges, _ = coalesce(both_directions, None, n_points, n_points)
    edge_index = merged_edges.data.cpu().numpy().T

    edge_cost = defaultdict(list)
    neighbors = defaultdict(list)
    edge_free = []
    for src, dst in edge_index:
        # The length is only computed for edges the environment reports as free.
        if env._edge_fp(points[src], points[dst]):
            cost, is_free = np.linalg.norm(points[dst] - points[src]), True
        else:
            cost, is_free = INFINITY, False
        edge_cost[dst].append(cost)
        edge_free.append(is_free)
        neighbors[dst].append(src)
    return edge_cost, neighbors, edge_index, edge_free


def min_dist(q, dist):
    """
    Pick the entry of q whose value in dist is smallest.

    Ties keep the entry met first while iterating q; an empty q yields None.
    Kept as a helper so the main search loop stays short.
    """
    remaining = iter(q)
    best = next(remaining, None)
    for candidate in remaining:
        # Strict comparison: an equal distance never displaces the current best.
        if dist[candidate] < dist[best]:
            best = candidate
    return best


def dijkstra(nodes, edges, costs, source):
    """
    Single-source shortest paths over an adjacency-list graph.

    Parameters:
        nodes: iterable of node identifiers.
        edges: mapping node -> sequence of adjacent nodes.
        costs: mapping node -> sequence of edge costs, aligned by position
            with edges[node].
        source: node the distances are measured from.

    Returns:
        dist: dict node -> best distance found (INFINITY if never improved).
        prev: dict node -> predecessor on that path; the source maps to itself
            and nodes that were never improved keep INFINITY.
    """
    dist, prev = {}, {}
    unvisited = set()

    # Every node starts unvisited, infinitely far away and without a predecessor.
    for node in nodes:
        dist[node] = INFINITY
        prev[node] = INFINITY
        unvisited.add(node)

    # The source is its own predecessor at distance zero.
    dist[source] = 0
    prev[source] = source

    while len(unvisited) > 0:
        # Settle the closest node that has not been processed yet.
        current = min_dist(unvisited, dist)
        unvisited.remove(current)

        for position, neighbor in enumerate(edges[current]):
            candidate = dist[current] + costs[current][position]
            if candidate < dist[neighbor]:
                # Going through `current` is strictly shorter; record it.
                dist[neighbor], prev[neighbor] = candidate, current

    return dist, prev



# Build one planning problem (start, goal, reference path) for each stored roadmap.
data = []
problems = []
env = UR5Env()

# Each entry of old_data is unpacked below as
# (points, neighbors, edge_cost, edge_index, edge_free).
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('data/pkl/ur5_prm_3000.pkl', 'rb') as f:
    old_data = pickle.load(f)

time0 = time()  # NOTE(review): not read again anywhere in this cell

pbar = tqdm(range(3000))
for problem_index in pbar:

    points, neighbors, edge_cost, edge_index, edge_free = old_data[problem_index]

    # Try start nodes in random order and keep the first one that has a usable goal.
    source_idxs = np.random.permutation(len(points))
    for source_index in source_idxs:
        dist, prev = dijkstra(list(range(len(points))), neighbors, edge_cost, source_index)
        # A valid goal has a finite, non-zero distance from the source. dist was filled
        # in node order 0..len(points)-1, so position i of its values belongs to node i.
        valid_goal = np.logical_and(np.array(list(dist.values())) != INFINITY, np.array(list(dist.values()))!=0)

        if np.sum(valid_goal) == 0:
            continue

        # Farthest valid node: argmax runs on the filtered distances, and the second
        # line maps that filtered position back to the node index.
        goal_index = np.array(list(dist.values()))[valid_goal].argmax()
        goal_index = np.arange(len(dist))[valid_goal][goal_index]

        # Follow predecessor links from the goal back to the source, then reverse
        # so the path runs source -> goal.
        path = []
        parent_idx = goal_index
        while parent_idx != source_index:
            path.append(points[parent_idx])
            parent_idx = prev[parent_idx]
        path.append(points[source_index])
        path.reverse()

        problems.append([env.problems[problem_index][0], points[source_index], points[goal_index], path])
        data.append((points, neighbors, edge_cost, edge_index, edge_free))
        break

# NOTE(review): index_3000 is not defined in this cell or in the visible part of the
# notebook; unless an earlier cell defines it, both dumps below raise NameError.
# Both files receive the same selection of problems.
with open('data/pkl/problems_ur5_3000.pkl', 'wb') as f:
    pickle.dump([problems[i] for i in index_3000], f, pickle.DEFAULT_PROTOCOL)

with open('maze_files/ur5s_6_3000.pkl', 'wb') as f:
    pickle.dump([problems[i] for i in index_3000], f, pickle.DEFAULT_PROTOCOL)

In [None]:
# Create a UR5 environment and display the first element of its first problem entry.
env = UR5Env()
env.problems[0][0]

In [None]:
# Write the complete `problems` list (no index selection) to both output files.
# Relies on `problems` and `pickle` from the problem-generation cell.
with open('data/pkl/problems_ur5_3000.pkl', 'wb') as f:
    pickle.dump(problems, f, pickle.DEFAULT_PROTOCOL)

with open('maze_files/ur5s_6_3000.pkl', 'wb') as f:
    pickle.dump(problems, f, pickle.DEFAULT_PROTOCOL)

# Figure 1

In [None]:
from environment import MazeEnv, UR5Env, SnakeEnv, KukaEnv, Kuka2Env
import torch
from train_next import str2next
from str2name import str2name
from algorithm.tsa import NEXT_plan, RRTS_plan
from config import set_random_seed
from eval_gnn import explore
import numpy as np
from tqdm import tqdm
import pybullet as p
from time import time
from algorithm.bit_star import BITStar
from eval_gnn import path_cost

# Compare BIT* against the GNN explorer on a Kuka problem and print the cases where
# the GNN run finished faster.
visualize = False

set_random_seed(3412)
env = KukaEnv(GUI=visualize)
p.resetDebugVisualizerCamera(
    cameraDistance=2.25,
    cameraYaw=-325,
    cameraPitch=-32,
    cameraTargetPosition=[0, 0, 0])
_, model_explore, _, model_smooth, _ = str2name(str(env), load=True)
# env.problems[pb_i][0].pop(3)

# NOTE(review): everything below is indented and reads `pb_i`, but no enclosing loop
# header is visible here. Presumably a `for pb_i in ...:` line was lost; it has to be
# restored before this cell can run.

    '-------------------------------------------------BIT*----------------------------------------------'

    # Same seed for both planners.
    set_random_seed(3412)
    env.init_new_problem(pb_i)
    c0 = env.collision_check_count
    bit_star = BITStar(env, batch_size=50, T=1000)
    _, _, c, length, _, t = bit_star.plan(float('inf'), 0, 180)
#     print('BIT*', t, env.collision_check_count-c0, length)
    t_bit = t

    '-------------------------------------------------GNN----------------------------------------------'

    set_random_seed(3412)
    c0 = env.collision_check_count
    env.init_new_problem(pb_i)
    # explore() returns its own c0 and t0, which replace the values assigned above.
    c0, cost, data, explored, forward, success, t0, value, path, smooth_path = \
        explore(env, model_explore, model_smooth, t_max=1000, batch=50, k=30)
#     print('GNN', time()-t0, env.collision_check_count-c0, forward, path_cost(smooth_path))
    t_gnn = time()-t0
    
    # Report only problems where the GNN run took less time than BIT*.
    if t_gnn < t_bit:
        print(pb_i, length, c,  env.collision_check_count-c0, path_cost(smooth_path), success, t_gnn, t_bit)

In [None]:
# Run the visualize_figure1.py script in a subshell.
!python visualize_figure1.py 

In [None]:
import numpy as np
import matplotlib.pyplot as plt
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8*' and newer
# releases no longer ship the old names, so use the new name whenever it exists
# and fall back to the old one on older installations.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

category_names = ['Collision Checking', 'Others']
# One row per planner: [value for the first category, remainder]. The remainder is
# written as (total - first value); presumably these are measured times in seconds
# from individual runs -- confirm against the original measurements.
results = {
    'GNN': [1.7223258018493652, 4.602714538574219-1.7223258018493652],
    'BIT*': [16.37259006500244, 28.814308643341064-16.37259006500244],
    'RRT* (Failed)': [16.088326930999756, 63.17657995223999-16.088326930999756],
    'NEXT (Failed)': [3.7327802181243896, 65.67735195159912-3.7327802181243896],
}


def survey(results, category_names):
    """
    Draw a horizontal stacked bar chart with one bar per entry of *results*.

    Parameters
    ----------
    results : dict
        Maps a bar label to a list of values, one per category. All lists
        must have the same length as *category_names*.
    category_names : list of str
        The category labels, shown in the legend.

    Returns
    -------
    fig, ax
        The figure and axes the chart was drawn on.
    """
    labels = list(results.keys())
    data = np.array(list(results.values()))
    # Running totals per row give the right edge of every stacked segment.
    data_cum = data.cumsum(axis=1)
    category_colors = plt.get_cmap('RdYlGn')(
        np.linspace(0.15, 0.85, data.shape[1]))

    fig, ax = plt.subplots(figsize=(6.05, 5))
    ax.invert_yaxis()
    ax.xaxis.set_visible(True)
    ax.set_xlim(0, np.sum(data, axis=1).max())

    for i, (colname, color) in enumerate(zip(category_names, category_colors)):
        # Only the collision-checking segments are hatched.
        hatch = '\\' if (colname == 'Collision Checking') else ''
        widths = data[:, i]
        # Each segment starts where the previous categories of the same row end.
        starts = data_cum[:, i] - widths
        ax.barh(labels, widths, left=starts, height=0.5,
                label=colname, color=color, hatch=hatch)

    # The bars already carry the result keys as category tick labels; only resize
    # them instead of replacing the tick formatter.
    ax.tick_params(axis='y', labelsize='x-large')
    ax.set_xlabel('Time (s)', fontsize='x-large')
    ax.set_title('Collision Checking vs. Other Queries', loc='center', fontsize='x-large')
    ax.legend(ncol=2, bbox_to_anchor=(1, 1),
              loc='upper right', fontsize='large')
    return fig, ax


# Render the stacked timing chart, save it as cc.pdf in the working directory,
# then display it.
fig, ax = survey(results, category_names)
fig.tight_layout()
plt.savefig("cc.pdf", dpi=10)
plt.show()

# Retrain the Explorer

In [None]:
from train_explorer import train_explorer
from str2env import str2env
from model_explore import Explorer

# Train an Explorer for the 'maze2easy' environment with use_obstacle=True.
# model_path is presumably where train_explorer stores the weights -- confirm.
# NOTE(review): the meaning of the positional arguments 32 and 2 is not visible here.
env, _ = str2env('maze2easy')
model = Explorer(env.dim, env.config_dim, 32, 2)
train_explorer(epoch=2000, data_path='./data/pkl/maze2_prm_k_20.pkl', model=model, 
               model_path='data/weights/2d_explorer.pt', env=env, use_obstacle=True, batch=2)

In [None]:
from train_explorer import train_explorer
from str2env import str2env
from model_explore import Explorer

# Same setup as the 2d_explorer.pt training cell, but with use_obstacle=False and a
# different model_path (2d_explorer_pure.pt).
env, _ = str2env('maze2easy')
model = Explorer(env.dim, env.config_dim, 32, 2)
train_explorer(epoch=2000, data_path='./data/pkl/maze2_prm_k_20.pkl', model=model, 
               model_path='data/weights/2d_explorer_pure.pt', env=env, use_obstacle=False, batch=2)

In [None]:
import torch
# Load the 2d_explorer_pure.pt weights into `model`, which must already exist from an
# earlier cell. NOTE(review): torch.load unpickles the file; only load trusted weights.
model.load_state_dict(torch.load('data/weights/2d_explorer_pure.pt'))

In [5]:
from environment import MazeEnv
from eval_gnn import *
import pybullet as p
from time import sleep
# Evaluate eval_gnn_pure on problem indices 2000..2999 of a 2-D maze with model=None.
# NOTE(review): `np` and `eval_gnn_pure` are not imported by name in this cell;
# presumably they come from `from eval_gnn import *` or an earlier cell.
env = MazeEnv(dim=2)
_ = eval_gnn_pure(str(env), 1234, env, np.arange(2000, 3000), model=None, model_s=None, use_tqdm=True, smooth=True, batch=100, t_max=1000, k=20)

gnn 0.01s, search 0.03s: 100%|██████████| 1000/1000 [01:27<00:00, 11.43it/s]

success rate: 970
collision check: 582.82
collision check explore: 396.60
running time: 0.06
path cost: 1.54
total time: 86.68
total time explore: 72.16






In [12]:
from eval_gnn import *
# Same eval_gnn_pure call as with model=None, now passing model2.
# NOTE(review): in the visible part of the notebook model2 is only assigned in cells
# further down, so this cell depends on out-of-order execution.
_ = eval_gnn_pure(str(env), 1234, env, np.arange(2000, 3000), model=model2, model_s=None, use_tqdm=True, smooth=True, batch=100, t_max=1000, k=20)

gnn 0.00s, search 0.30s: 100%|██████████| 1000/1000 [09:41<00:00,  1.72it/s]

success rate: 615
collision check: 6281.89
collision check explore: 3070.66
running time: 0.70
path cost: 10.83
total time: 538.37
total time explore: 323.21






In [16]:
# Sets the attribute `use_obstacles` on model2. NOTE(review): train_explorer takes a
# keyword spelled `use_obstacle`; whether the model reads this attribute is not
# visible here -- confirm.
model2.use_obstacles = False

In [4]:
from eval_gnn import *
# eval_gnn_pure with model2 moved to `device` and k=10.
# NOTE(review): `device` is not defined in this cell, and the recorded output of this
# run is "NameError: name 'points' is not defined".
_ = eval_gnn_pure(str(env), 1234, env, np.arange(2000, 3000), model=model2.to(device), model_s=None, use_tqdm=True, smooth=True, batch=100, t_max=1000, k=10)

  0%|          | 0/1000 [00:00<?, ?it/s]


NameError: name 'points' is not defined

In [7]:
from eval_gnn import *
# Same arguments as the failing eval_gnn_pure call, but through eval_gnn.
# NOTE(review): `device` and `model2` must already exist in the kernel.
_ = eval_gnn(str(env), 1234, env, np.arange(2000, 3000), model=model2.to(device), model_s=None, use_tqdm=True, smooth=True, batch=100, t_max=1000, k=10)

gnn 0.01s, search 0.03s: 100%|██████████| 1000/1000 [00:55<00:00, 18.09it/s]

success rate: 1000
collision check: 500.08
collision check explore: 342.91
running time: 0.05
path cost: 1.35
total time: 54.45
total time explore: 40.48






In [5]:
from eval_gnn import *
# eval_gnn with model=None on the same problem indices, seed and k=10 as the model2 run.
_ = eval_gnn(str(env), 1234, env, np.arange(2000, 3000), model=None, model_s=None, use_tqdm=True, smooth=True, batch=100, t_max=1000, k=10)

gnn 0.01s, search 0.03s: 100%|██████████| 1000/1000 [00:56<00:00, 17.65it/s]

success rate: 1000
collision check: 499.48
collision check explore: 342.55
running time: 0.06
path cost: 1.35
total time: 55.83
total time explore: 41.92






In [22]:
# Repeats the `use_obstacles = False` assignment on model2 (this statement appears in
# more than one cell of the notebook).
model2.use_obstacles = False

In [11]:
from train_explorer import train_explorer
from str2env import str2env
from model_explore import Explorer

# Construct an Explorer for the 'ur5' environment. The training call is commented
# out, so model2 is only constructed here, not trained.
env, _ = str2env('ur5')
model2 = Explorer(env.dim, env.config_dim, 32, 6)
# train_explorer(epoch=2000, data_path='./data/pkl/ur5_prm_3000.pkl', model=model2, 
#                model_path='data/weights/ur5_explorer.pt', env=env, use_obstacle=False)

In [1]:
from train_explorer import train_explorer
from str2env import str2env
from model_explore import Explorer

# Construct an Explorer for 'maze2easy' with 128 as the third constructor argument
# (the other maze cells pass 32). Training is commented out, so model2 is only
# constructed here, not trained.
env, _ = str2env('maze2easy')
model2 = Explorer(env.dim, env.config_dim, 128, 2)
# train_explorer(epoch=2000, data_path='./data/pkl/maze_prm.pkl', model=model2, 
#                model_path='data/weights/2d_explorer_128.pt', env=env, use_obstacle=False)

In [4]:
# Keep a reference to model2.labels for inspection.
# NOTE(review): where `labels` gets set on the model is not visible in this notebook section.
a = model2.labels

In [5]:
# Display the labels; the recorded output is a tensor with three columns per row.
a

tensor([[1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 1.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1

In [9]:
import torch
# Rebuild model2 and load the 2d_explorer_eta_on_x.pt weights into it.
# NOTE(review): `Explorer` and `env` are not defined in this cell and must come from
# an earlier one; torch.load unpickles the file, so only load trusted weights.
model2 = Explorer(env.dim, env.config_dim, 32, 2)
model2.load_state_dict(torch.load('data/weights/2d_explorer_eta_on_x.pt'))

In [3]:
import os
# Set CUDA_LAUNCH_BLOCKING=1 for this process, once through os.environ and once more
# through the %set_env magic (both lines target the same variable).
os.environ['CUDA_LAUNCH_BLOCKING']='1'
%set_env CUDA_LAUNCH_BLOCKING=1

env: CUDA_LAUNCH_BLOCKING=1


In [1]:
from train_explorer import train_explorer
from str2env import str2env
from model import EncoderProcessDecoder

# Train an EncoderProcessDecoder (imported from `model`) on 'maze2easy' with the
# maze_prm_4000.pkl dataset, use_obstacle=False and no explicit batch argument.
env, _ = str2env('maze2easy')
model = EncoderProcessDecoder(env.dim, env.config_dim, 32, 2)
train_explorer(epoch=2000, data_path='./data/pkl/maze_prm_4000.pkl', model=model, 
               model_path='data/weights/2d_explorer_orig.pt', env=env, use_obstacle=False)

total 2.00, value 0.00, policy 2.00, node 0.00, edge 0.00: 100%|██████████| 2000/2000 [00:31<00:00, 63.05it/s]
total 0.55, value 0.00, policy 0.55, node 0.00, edge 0.00: 100%|██████████| 2000/2000 [00:28<00:00, 70.59it/s]
total 0.61, value 0.00, policy 0.61, node 0.00, edge 0.00: 100%|██████████| 2000/2000 [00:27<00:00, 73.15it/s] 
total 0.76, value 0.00, policy 0.76, node 0.00, edge 0.00: 100%|██████████| 2000/2000 [00:26<00:00, 74.38it/s]
total 0.69, value 0.00, policy 0.69, node 0.00, edge 0.00: 100%|██████████| 2000/2000 [00:26<00:00, 74.17it/s]
total 0.37, value 0.00, policy 0.37, node 0.00, edge 0.00: 100%|██████████| 2000/2000 [00:26<00:00, 75.53it/s]
total 0.54, value 0.00, policy 0.54, node 0.00, edge 0.00: 100%|██████████| 2000/2000 [00:26<00:00, 76.57it/s] 
total 1.08, value 0.00, policy 1.08, node 0.00, edge 0.00: 100%|██████████| 2000/2000 [00:26<00:00, 74.85it/s]
total 0.49, value 0.00, policy 0.49, node 0.00, edge 0.00: 100%|██████████| 2000/2000 [00:26<00:00, 74.92it/s]

In [12]:
# Load weights_maze.pt into the existing `model`; the recorded output reports that all
# keys matched. NOTE(review): `torch` is not imported in the preceding training cell
# and must already be available in the kernel.
model.load_state_dict(torch.load('data/weights/weights_maze.pt'))

<All keys matched successfully>

In [1]:
# List the contents of the current working directory.
!ls

 agent.py	        play_batch_size.py
 algorithm	        play.ipynb
 cc.pdf		        play.py
 config.py	        play_smooth.py
 data		        play_train_size.py
 diversify_maze.py      process_batch_txt.py
 environment	        process_probing_samples.py
 eval_2d_explore.py     process_smoother.py
 eval_2d.py	        process_txt.py
 eval_2d_smoother.py    __pycache__
 eval_2d_visualize.py   README.md
 eval_7d_explore.py     replay.py
 eval_7d_hier.py        runs
 eval_7d_mpnet.py       smoother.py
 eval_7d_radius.py      snake_prm.py
 eval_all.py	        stats.py
 eval_batch.py	        str2env.py
 eval_bit.py	        str2name.py
 eval_gnn.py	        test_bit.py
 eval_next.py	        test.py
 eval_only_gnn.py       train_2d_explore.py
 eval_rrt.py	        train_2d.py
 GNN-MP-2d.ipynb        train_2d_smoother.py
 hyperspherical_vae    'train_7d_ backup.py'
 kuka_bit_explore.py    train_7d_bpr.py
 kuka_iiwa	        train_7d_explore.py
 kuka_rrt_explore.py    train_7d_mpnet.py
 kuka_rrt.py	 