In [None]:
from rl.trainer import Trainer

# Build the RL trainer used by step() and by the training loop below.
# NOTE(review): the meaning of the positional arguments (10, 2) is not visible here —
# presumably observation and action sizes; confirm against rl.trainer.Trainer.
trainer = Trainer(10, 2)

In [None]:
import getpass

# Login name of the current user; interpolated into the Omniverse USD path
# when the SimulationApp is created.
user = getpass.getuser()

In [None]:
from omni.isaac.kit import SimulationApp

# The stage to open and the live-sync target are the same per-user USD file,
# so the URL is built once and reused for both settings.
jetbot_usd_url = f"omniverse://localhost/Users/{user}/test_jetbot.usd"

simulation_app = SimulationApp(
    {
        "headless": True,
        "open_usd": jetbot_usd_url,
        "livesync_usd": jetbot_usd_url,
    }
)


In [None]:
import omni
from omni.isaac.core import World
from pxr import Gf, Sdf, UsdGeom, UsdShade

import torch
import numpy as np

In [None]:
# Create the simulation world with a physics_dt of 1/30 and the "torch" backend
# (later cells treat the robot state returned by the simulator as torch tensors).
# NOTE(review): physics_dt is presumably in seconds, i.e. 30 physics steps per second — confirm.
world = World(physics_dt= 1.0 / 30, backend = "torch")

In [None]:
from omni.isaac.core.utils.stage import set_stage_up_axis

# Make "y" the stage's up axis, then set the gravity value on the physics context.
# NOTE(review): set_gravity(-9.81) is presumably applied along the up axis — the
# following cells print the stage up axis and the gravity direction to check this.
set_stage_up_axis("y")
world.get_physics_context().set_gravity(-9.81)

In [None]:
from omni.isaac.core.utils.stage import get_current_stage
# Display the up axis of the current stage (sanity check after set_stage_up_axis).
UsdGeom.GetStageUpAxis(get_current_stage())

In [None]:
# Display the gravity direction attribute stored on the physics scene.
# Reads the private `_physics_scene` attribute, so this may break if the library changes.
world.get_physics_context()._physics_scene.GetGravityDirectionAttr().Get()

In [None]:
# Count the prims on the stage: TraverseAll() is materialized into a list
# only so that its length can be displayed.
len(list(world.scene.stage.TraverseAll()))

In [None]:
# Render the world once (the simulation app was created with "headless": True).
world.render()

In [None]:
from rl.robot_env import RobotEnv

# Build the robot environment from a prim-path pattern containing a "*" wildcard.
# NOTE(review): presumably this matches one jetbot per sub-environment under /World/envs,
# and the meaning of [10.0, 10.0] is not visible here — confirm against rl.robot_env.RobotEnv.
env = RobotEnv("/World/envs/*/jetbot", [10.0, 10.0])

In [None]:
# Reset the world before the robot environment is started in the next cell.
# NOTE(review): presumably required so the simulation is initialized before env.start() — confirm.
world.reset()

In [None]:
# Start the robot environment, then register its `robots` object with the world's scene.
env.start()
world.scene.add(env.robots)

In [None]:
# Global step counter: incremented once per iteration of the training loop and
# read by step() to decide whether the random warm-up phase is over.
total_step = 0

In [None]:
# Shape of the joint positions in the robots' default joint state; step() draws
# random warm-up actions of this shape. Reads the private `_default_joints_state` attribute.
action_shape = env.robots._default_joints_state.positions.shape

In [None]:
# step
def step(warm_up_steps = 1000):
    """Advance the per-env step counters and send one batch of joint-velocity commands.

    While the global ``total_step`` is below ``warm_up_steps`` the commands are
    uniform random values in [-10, 10); afterwards they are 10x the action the
    trainer samples for the global ``current_obs``.

    Args:
        warm_up_steps: number of global steps that are driven by random actions.

    Returns:
        The action tensor that was passed to ``env.robots.set_joint_velocities``.
    """
    env.progress_buf += 1

    still_warming_up = total_step < warm_up_steps
    if still_warming_up:
        # torch.rand is uniform on [0, 1); rescale it to [-1, 1).
        raw_actions = 2 * torch.rand(action_shape) - 1
    else:
        policy_input = current_obs.to(trainer.device)
        raw_actions = trainer.sample_action(policy_input)

    actions = 10 * raw_actions
    env.robots.set_joint_velocities(actions)
    return actions

In [None]:
# observation
def get_obs():
    """Assemble the per-robot observation tensor from the current simulator state.

    The pieces, in column order, are: components 1 and 2 of the world position,
    the world rotation, the joint positions, and the joint velocities scaled by 0.1.

    Returns:
        The pieces above concatenated along dim 1.
    """
    robots = env.robots
    position, rotation = robots.get_world_poses()
    joint_pos = robots.get_joint_positions()
    joint_vel = robots.get_joint_velocities()

    pieces = (
        position[..., [1, 2]],
        rotation,
        joint_pos,
        0.1 * joint_vel,
    )
    return torch.cat(pieces, dim=1)

In [None]:
def get_reward_done():
    """Compute the per-robot reward and done flag from the current world poses.

    Reward is 10x component 2 of each robot's world position. A robot is done
    (flag 1.0, otherwise 0.0) when that component is below -0.5 or when its
    entry in ``env.progress_buf`` exceeds 1000.

    Returns:
        (reward, done): two tensors with one entry per robot.
    """
    torso_position, _ = env.robots.get_world_poses()
    z = torso_position[..., 2]

    # The multiplication already yields a new tensor, so no explicit clone is needed.
    reward = 10 * z

    done = torch.where(z < -0.5, 1.0, 0.0)
    # ones_like keeps the replacement values on the same device/dtype as `done`;
    # a bare torch.ones(...) is always allocated on the CPU and would mismatch
    # when the simulation tensors live on another device.
    timed_out = env.progress_buf > 1000
    done = torch.where(timed_out, torch.ones_like(done), done)

    return reward, done

In [None]:
# Prime the loop state: keep a detached copy of the first observation and
# evaluate reward/done once before training starts.
# `.detach()` replaces the legacy `.data` accessor, which bypasses autograd's safety checks.
current_obs = get_obs().detach().clone()
reward, done = get_reward_done()

In [None]:
# Display the initial observation, reward and done tensors for inspection.
current_obs, reward, done

In [None]:
# need to reset
def reset(done):
    """Reset every environment whose done flag is set.

    Does nothing when the flags sum to less than 1; otherwise passes the list of
    flagged indices (plain Python ints, in ascending order) to ``env.reset_idx``.

    Args:
        done: tensor of per-environment done flags, indexed by environment along
            dim 0; entries greater than 0 mark environments to reset.
    """
    if not torch.sum(done) >= 1:
        return

    # One vectorized lookup instead of indexing the tensor element by element
    # from Python; column 0 of nonzero() is the index along dim 0.
    env_ids = torch.nonzero(done > 0)[:, 0].tolist()
    env.reset_idx(env_ids)

In [None]:
# Advance the simulation by one step with rendering enabled, before entering the training loop.
world.step(render=True)

In [None]:
# Main collection/training loop: each iteration applies one action, advances the
# simulation by one step, stores the resulting transition, and (after the first
# 1000 steps) runs one trainer update.
for i in range(100000):
    total_step += 1
    
    # Observation before acting. `current_obs` is also the global that step()
    # feeds to the trainer once the warm-up phase is over.
    current_obs = get_obs().clone()
    # Choose actions (random during warm-up, sampled from the trainer afterwards)
    # and apply them as joint velocities.
    actions = step()
    # actions = 10 * torch.ones(4, 2)
    # env.robots.set_joint_velocities(10 * torch.ones(4, 2))
    world.step(render=False)
    
    # Observation after the simulation step; taken before any reset below.
    new_obs = get_obs().clone()
    
    # Reward and done flags for the state just reached.
    reward, done = get_reward_done()
    
    # Reset the environments flagged as done.
    reset(done)
    
    # Store the transition. The actions are scaled by 0.1, which undoes the 10x
    # factor step() applies before sending them to the robots.
    trainer.buf.add_batch(current_obs, 0.1 * actions, new_obs, reward, done)
    # print(total_step, "current_obs, actions, new_obs, reward, done \n\n", current_obs, actions, new_obs, reward, done)
    

    # Progress report once every 1000 steps (total_step = 99, 1099, ...):
    # print the mean reward and render the world.
    if total_step % 1000 == 99:
        print(total_step, "reward", torch.mean(reward).tolist())
        world.render()
    
    # Train only after more than 1000 steps have been collected. This threshold is
    # hard-coded separately from step()'s default warm_up_steps of 1000.
    if total_step > 1000:
        trainer.train_debug(batch_size = 32)
    
    

In [None]:
# Reset the world again once the training loop has finished.
world.reset()

In [None]:
# Render the world once after the final reset.
world.render()