In [11]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax.nn import relu, softmax
import mujoco
from mujoco import mjx
import mediapy as media
from MujocoSim import MujocoSim
import numpy as np
import csv
import os


import jax
import jax.numpy as jnp
from jax.numpy.linalg import norm
from jax import grad, jit, vmap
from jax.nn import relu, softmax
import mujoco
from mujoco import mjx
import mediapy as media
import mujoco.mjx
import subprocess
import distutils.util

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" 
if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

try:
  print('Checking that the installation succeeded:')
  import mujoco
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".')

print('Installation successful.')

Tue May 28 20:23:38 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.171.04             Driver Version: 535.171.04   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3080 ...    Off | 00000000:01:00.0  On |                  N/A |
| N/A   65C    P0              40W / 150W |   5149MiB / 16384MiB |      6%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

Environment

In [12]:
class MujocoSim:
    """MuJoCo / MJX simulation of the arm-and-box scene in ``simple_arm/scene.xml``.

    State vector produced by :meth:`get_state` (27 entries):
        0-6    ``qpos[0:7]``   -- robot arm joint positions (0-5) and finger position (6)
        7-13   ``qpos[7:14]``  -- box position (7-9) and box quaternion (10-13)
        14-20  ``qvel[0:7]``   -- arm / finger velocities
        21-26  ``qvel[7:13]``  -- box linear and angular velocity

    Actions:
        0-5  robot arm joints
        6    finger torque

    NOTE(review): the initial ``qpos`` has 15 entries and the control vector
    used elsewhere in this notebook has 8, which suggests the model has one
    more arm/finger joint than the index ranges above assume -- confirm the
    slices against ``scene.xml``.
    """

    SCENE_XML = 'simple_arm/scene.xml'
    # Physics steps performed per call to `step` (the original loop of 10 plus 1).
    SUBSTEPS = 11
    # Geom whose world position is used as the reference point in the reward;
    # presumably the arm tip -- TODO confirm against the scene file.
    TIP_GEOM_ID = 16

    def __init__(self):
        """Load the scene and set the reward weights and termination limits."""
        # Reward weights, in order (names taken from the original notes):
        #   load_target_dist, tip_to_load_position, tip_to_load_velocity,
        #   current_torque_cost, peak_torque_cost, timestep
        self._load(
            weights=[-1, -45, 0, 1, 1, 1],
            init_qpos=[16e-6, -0.0007421, -0.047, 0.06, -4e-05, 2.33e-5, 0.0009,
                       0, 3, 0, 0.0198922, 1, 0, 0, 0],
        )
        self.max_allowable_distance = 4
        self.max_allowable_target_error = 0.1

    def _load(self, weights, init_qpos):
        """Rebuild the MuJoCo and MJX model/data and apply the initial pose."""
        self.mj_model = mujoco.MjModel.from_xml_path(self.SCENE_XML)
        self.mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
        self.mj_model.opt.iterations = 6
        self.mj_model.opt.ls_iterations = 6
        self.mj_data = mujoco.MjData(self.mj_model)

        self.weights = jnp.array(weights).transpose()
        self.load_dest = jnp.array([1, 1, 1]).transpose()

        self.mjx_model = mjx.put_model(self.mj_model)
        self.mjx_data = mjx.put_data(self.mj_model, self.mj_data)

        # `replace` returns a new object, so the result must be stored back;
        # the original discarded it and the initial pose was never applied.
        init_qpos = jnp.array(init_qpos)
        if init_qpos.shape == self.mjx_data.qpos.shape:
            self.mjx_data = self.mjx_data.replace(qpos=init_qpos)
        else:
            import warnings
            warnings.warn(
                f"initial qpos has shape {init_qpos.shape} but the model expects "
                f"{self.mjx_data.qpos.shape}; keeping the model's default pose")

    def reset(self):
        """Reload the scene and return the initial state vector.

        Note that this also replaces ``self.weights`` with a different set of
        reward weights than ``__init__`` uses (kept from the original code).
        """
        self._load(
            weights=[1, 1, 1, 1, 1, 0.1],
            init_qpos=[16e-6, -0.0007421, -0.047, 0.06, -4e-05, 2.33e-5, 0.0009,
                       0, 5, 0, 0.0198922, 1, 0, 0, 0],
        )
        return self.get_state(self.mjx_data)

    def step(self, model, data, action):
        """Apply ``action`` as control, advance the physics, and score the result.

        Args:
            model: MJX model passed to ``mjx.step``.
            data: MJX data to advance; it is not modified, a new object is returned.
            action: control vector. When it is 1-D and shorter than ``data.ctrl``
                it is written into the leading entries and the remaining
                controls keep their current values.

        Returns:
            ``(state, reward, data)`` where ``data`` is the advanced MJX data.
        """
        action = jnp.asarray(action)
        ctrl = data.ctrl
        if action.ndim == 1 and ctrl.ndim == 1 and action.shape[0] < ctrl.shape[0]:
            new_ctrl = ctrl.at[:action.shape[0]].set(action)
        else:
            new_ctrl = action
        # Results of `replace` and `mjx.step` are rebound because both return
        # new objects; the original dropped them, so the action was never
        # applied and only a single physics step took effect.
        data = data.replace(ctrl=new_ctrl)
        for _ in range(self.SUBSTEPS):
            data = mjx.step(model, data)

        state = self.get_state(data)
        reward = self.get_reward(state, action, data)
        return state, reward, data

    def batch_step(self, states, actions):
        """Step the environment's own model/data with ``actions``.

        ``states`` is accepted for interface compatibility but is not used.
        Returns the ``(state, reward, data)`` tuple from :meth:`step`.
        """
        return self.step(self.mjx_model, self.mjx_data, actions)

    def get_reward(self, state, action, data=None):
        """Weighted sum of the reward terms for one transition.

        Args:
            state: state vector as produced by :meth:`get_state`.
            action: action vector; entries 0-5 enter the torque cost.
            data: MJX data providing ``geom_xpos`` for the reference geom.
                Defaults to ``self.mjx_data`` (the original behaviour), which
                does not reflect the stepped simulation.
        """
        if data is None:
            data = self.mjx_data
        ref_pos = data.geom_xpos[self.TIP_GEOM_ID]
        box_pos = state[7:10]
        # NOTE(review): the third term subtracts state[22:25] from a position;
        # its weight is 0 in __init__ but 1 after reset() -- confirm intent.
        return sum([self.weights[0] * norm(box_pos - self.load_dest),
                    self.weights[1] * norm(box_pos - ref_pos),
                    self.weights[2] * norm(ref_pos - state[22:25]),
                    self.weights[3] * norm(action[0:6]) ** 2,
                    self.weights[4] * 1])

    def get_state(self, data):
        """Concatenate positions and velocities from ``data`` into a 27-vector."""
        # The last segment reads qvel; the original read qpos[7:13] here, which
        # only duplicated entries 7-12 of the state in the velocity slot.
        return jnp.concatenate([data.qpos[0:7], data.qpos[7:14],
                                data.qvel[0:7], data.qvel[7:13]])

    def isnt_done(self, state):
        """Return a value that is positive while the episode should continue.

        The result is the product of two clamped margins: the remaining
        distance before the box is ``max_allowable_distance`` from the origin,
        and how far the box still is from ``load_dest`` beyond
        ``max_allowable_target_error``. It is zero once either margin is used up.
        """
        box_pos = jnp.asarray(state[7:10]).transpose()
        origin = jnp.array([0, 0, 0]).transpose()
        # jnp.maximum (rather than Python max) keeps this usable under jit/vmap.
        range_margin = jnp.maximum(
            self.max_allowable_distance - norm(box_pos - origin), 0)
        target_margin = jnp.maximum(
            norm(box_pos - self.load_dest) - self.max_allowable_target_error, 0)
        return range_margin * target_margin

Neural Networks

In [13]:
def clip_grads(grads, max_norm):
    """Scale a dict of gradient arrays so their joint L2 norm is at most ``max_norm``.

    The norm is taken over all arrays together. Gradients already within the
    limit are returned unscaled (the scale factor is capped at 1).
    """
    squared_total = 0
    for g in grads.values():
        squared_total = squared_total + jnp.sum(g ** 2)
    global_norm = jnp.sqrt(squared_total)
    # The small epsilon guards the division when every gradient is zero.
    scale = jnp.minimum(1.0, max_norm / (global_norm + 1e-6))
    clipped = {}
    for name, g in grads.items():
        clipped[name] = g * scale
    return clipped
# Network definitions
@jit
def actor_network(params, state):
    """Two-layer policy network: ReLU hidden layer followed by a softmax output.

    ``params`` is a dict with keys ``'W1'``, ``'b1'``, ``'W2'`` and ``'b2'``.
    """
    w_in, b_in = params['W1'], params['b1']
    w_out, b_out = params['W2'], params['b2']
    pre_activation = jnp.dot(state, w_in) + b_in
    features = relu(pre_activation)
    return softmax(jnp.dot(features, w_out) + b_out)

@jit
def critic_network(params, state):
    """Two-layer value network: ReLU hidden layer followed by a linear output.

    ``params`` is a dict with keys ``'W1'``, ``'b1'``, ``'W2'`` and ``'b2'``.
    """
    w_in, b_in = params['W1'], params['b1']
    w_out, b_out = params['W2'], params['b2']
    features = relu(jnp.dot(state, w_in) + b_in)
    return jnp.dot(features, w_out) + b_out


def initialize_params(input_dim, hidden_dim, output_dim, seed=None):
    """Create parameters for a two-layer network.

    Args:
        input_dim: size of the input vector.
        hidden_dim: number of hidden units.
        output_dim: size of the output vector.
        seed: optional seed for reproducible weights. When ``None`` (default)
            the weights are drawn from NumPy's global random state, exactly as
            before this argument existed.

    Returns:
        dict with ``'W1'`` (input_dim, hidden_dim), ``'b1'`` (hidden_dim,),
        ``'W2'`` (hidden_dim, output_dim) and ``'b2'`` (output_dim,). Weights
        are standard-normal samples scaled by 0.01; biases are zero.
    """
    if seed is None:
        draw = np.random.randn
    else:
        generator = np.random.default_rng(seed)

        def draw(*shape):
            return generator.standard_normal(shape)

    params = {
        'W1': jnp.array(draw(input_dim, hidden_dim) * 0.01),
        'b1': jnp.zeros(hidden_dim),
        'W2': jnp.array(draw(hidden_dim, output_dim) * 0.01),
        'b2': jnp.zeros(output_dim)
    }
    return params

# Network sizes: the input width is the 27-entry state vector, the actor emits
# one output per action (7) and the critic emits a single value.
input_dim, hidden_dim = 27, 128
output_dim_actor, output_dim_critic = 7, 1

# Actor parameters are drawn first, then the critic's (same order as before).
actor_params, critic_params = (
    initialize_params(input_dim, hidden_dim, out_dim)
    for out_dim in (output_dim_actor, output_dim_critic)
)

Actor Critic

In [14]:
class ActorCritic:
    """One-step actor-critic trainer driving a ``MujocoSim`` environment.

    The actor and critic are the module-level ``actor_network`` and
    ``critic_network`` functions; their parameters are plain dicts of arrays.
    Rollouts can be logged to CSV files under ``logs/``.
    """

    LOG_DIR = 'logs'

    def __init__(self, env: "MujocoSim", actor_params: dict, critic_params: dict, lr: float = 0.01):
        """Store the environment, initial parameters and learning rate.

        The parameter dicts are shallow-copied so that training does not
        overwrite entries in the caller's dicts.
        """
        self.env = env
        self.actor_params = dict(actor_params)
        self.critic_params = dict(critic_params)
        self.lr = lr
        self._log_file = None

    def select_action(self, state, actor_params):
        """Return the actor network output for ``state``."""
        return actor_network(actor_params, state)

    def update(self, state, action, reward, next_state, actor_params, critic_params):
        """Compute one TD(0) actor-critic update and return the new parameters.

        Args:
            state, next_state: state vectors before and after the transition.
            action: action taken (not used by the current loss).
            reward: reward received for the transition.
            actor_params, critic_params: parameter dicts to update from.

        Returns:
            ``(actor_params, critic_params)`` as new dicts; inputs are not modified.
        """
        gamma = 0.99  # Discount factor
        max_grad_norm = 1.0  # Joint gradient-norm limit applied to each network

        # The TD target and error are computed from the incoming critic
        # parameters, so they are constants with respect to both losses below.
        value = critic_network(critic_params, state)
        next_value = critic_network(critic_params, next_state)
        td_target = reward + gamma * next_value
        td_error = td_target - value

        def critic_loss(params):
            return jnp.mean((td_target - critic_network(params, state)) ** 2)

        critic_grads = clip_grads(grad(critic_loss)(critic_params), max_grad_norm)
        # Gradient *descent* on the critic loss (the original added the
        # gradient, which increases the loss).
        critic_params = {k: critic_params[k] - self.lr * critic_grads[k] for k in critic_params}

        # NOTE(review): the key is fixed, so every update sees the same noise.
        key = jax.random.PRNGKey(0)
        std_devs = jnp.ones([1, 7]) * 50

        def actor_loss(params):
            # The network is evaluated with `params`, the argument being
            # differentiated. The original used the enclosing `actor_params`,
            # which made this gradient identically zero.
            # NOTE(review): `[0]` keeps only the first network output and
            # broadcasts it across all 7 entries -- confirm this is intended.
            pi = jnp.abs(jax.random.normal(key, (1, 7)) * std_devs
                         + actor_network(params, state)[0])
            pi = jnp.where(pi > 0.01, pi, 0.01)  # keep the log finite
            return jnp.mean(jnp.log(pi) * td_error)

        actor_grads = clip_grads(grad(actor_loss)(actor_params), max_grad_norm)
        # Gradient ascent on log-probability weighted by the TD error.
        actor_params = {k: actor_params[k] + self.lr * actor_grads[k] for k in actor_params}
        return actor_params, critic_params

    def batch_train(self, episodes=20, batch_size=4096):
        """Train on ``batch_size`` simulations stepped in parallel.

        Each of the ``episodes`` outer iterations runs 1000 update steps. The
        simulation batch is not reset between episodes. This method has Python
        side effects (logging, attribute updates) and must be called directly,
        not wrapped in ``jax.jit``.
        """
        select_action = jax.vmap(self.select_action, in_axes=(0, None))
        update = jax.jit(jax.vmap(self.update, in_axes=(0, 0, 0, 0, None, None)))
        get_state = jax.vmap(self.env.get_state)
        step = jax.jit(jax.vmap(self.env.step, in_axes=(None, 0, 0)))

        # One copy of the environment data per batch entry, each with random controls.
        rngs = jax.random.split(jax.random.PRNGKey(0), batch_size)
        batch = jax.vmap(
            lambda rng: self.env.mjx_data.replace(ctrl=jax.random.uniform(rng, (8,)))
        )(rngs)

        for episode in range(episodes):
            state = get_state(batch)
            action = select_action(state, self.actor_params)
            next_state, reward, batch = step(self.env.mjx_model, batch, action)
            if episode == 0:
                self.log_header(jnp.mean(state, axis=0), jnp.mean(action, axis=0), jnp.mean(reward, axis=0))
            # Continue from the state the batch is actually in, so that each
            # (state, action, next_state) passed to `update` is one real transition.
            state = next_state

            print("ep", episode)
            for _ in range(1000):
                action = select_action(state, self.actor_params)
                next_state, reward, batch = step(self.env.mjx_model, batch, action)

                # `update` returns one parameter set per batch entry; average them.
                actor_params, critic_params = update(
                    state, action, reward, next_state, self.actor_params, self.critic_params)
                for key in actor_params:
                    self.actor_params[key] = jnp.mean(actor_params[key], axis=0)
                for key in critic_params:
                    self.critic_params[key] = jnp.mean(critic_params[key], axis=0)
                state = next_state

    def train(self, episodes=1):
        """Train on a single simulation, logging every step to ``logs/ac_log.csv``."""
        for episode in range(episodes):
            # reset() is called directly: the original passed its *return value*
            # to jax.jit, which is not a callable.
            state = self.env.reset()
            batch = self.env.mjx_data.replace(ctrl=jnp.ones([8]))

            action = self.select_action(state, self.actor_params)
            next_state, reward, batch = self.env.step(self.env.mjx_model, batch, action)
            if episode == 0:
                self.log_header(state, action, reward)
            state = next_state

            print("ep", episode)
            for _ in range(10):
                action = self.select_action(state, self.actor_params)
                next_state, reward, batch = self.env.step(self.env.mjx_model, batch, action)

                actor_params, critic_params = self.update(
                    state, action, reward, next_state, self.actor_params, self.critic_params)
                for key in actor_params:
                    self.actor_params[key] = actor_params[key]
                for key in critic_params:
                    self.critic_params[key] = critic_params[key]
                state = next_state

                print(f"action: {action}")
                self.log_line(state, action, reward)

    def _open_output(self, filename):
        """Open ``logs/<filename>`` for writing, creating the directory if needed."""
        os.makedirs(self.LOG_DIR, exist_ok=True)
        return open(os.path.join(self.LOG_DIR, filename), 'w', newline='')

    def close_log(self):
        """Close the rollout log file if one is open."""
        log_file = getattr(self, "_log_file", None)
        if log_file is not None and not log_file.closed:
            log_file.close()
        self._log_file = None

    def log_header(self, state, action, reward):
        """Start a fresh ``logs/ac_log.csv`` and write its header row.

        Only the lengths of ``state`` and ``action`` are used, to name the
        columns ``state_<i>`` and ``action_<i>``; the final column is ``reward``.
        Any existing log file is replaced.
        """
        header_text = [f"state_{i}" for i in range(len(state))]
        header_text += [f"action_{i}" for i in range(len(action))]
        header_text.append("reward")

        self.close_log()  # do not leak the handle from a previous run
        self._log_file = self._open_output("ac_log.csv")
        self.data_writer = csv.writer(self._log_file)
        print(header_text)
        self.data_writer.writerow(header_text)
        self._log_file.flush()

    def log_newline(self):
        """Append ``"0\\n"`` to the in-memory ``log_text`` buffer."""
        # The buffer is created on first use; it was never initialized before.
        self.log_text = getattr(self, "log_text", "") + "0\n"

    def log_line(self, state, action, reward):
        """Append one row (state entries, action entries, reward) to the log.

        ``log_header`` must have been called first to open the file.
        """
        row = [str(s) for s in state]
        row.extend(str(a) for a in action)
        row.append(reward)
        self.data_writer.writerow(row)
        # Flush so the row is on disk even if the run is interrupted.
        log_file = getattr(self, "_log_file", None)
        if log_file is not None:
            log_file.flush()

    def save_parameters(self):
        """Write actor then critic parameters as a single row of ``logs/params.csv``.

        For each parameter the row holds its key followed by ``str()`` of each
        item obtained by iterating over the array. Any existing file is replaced.
        """
        row = []
        for key in self.actor_params:
            row.append(key)
            row.extend(str(item) for item in self.actor_params[key])
        for key in self.critic_params:
            row.append(key)
            # Read from the critic's own parameters (the original re-read
            # self.actor_params here, so critic weights were never saved).
            row.extend(str(item) for item in self.critic_params[key])

        with self._open_output("params.csv") as data_f:
            csv.writer(data_f).writerow(row)

    def use_parameters(self):
        """Run one 100-step rollout with the current actor, log it and show a video."""
        state = self.env.reset()
        batch = self.env.mjx_data.replace(ctrl=jnp.ones([8]))
        step = jax.jit(self.env.step)

        action = self.select_action(state, self.actor_params)
        next_state, reward, batch = step(self.env.mjx_model, batch, action)
        self.log_header(state, action, reward)
        state = next_state

        renderer = mujoco.Renderer(self.env.mj_model)
        frames = []
        framerate = 100
        print("use_parameters", 1)
        for _ in range(100):
            action = self.select_action(state, self.actor_params)
            next_state, reward, batch = step(self.env.mjx_model, batch, action)
            state = next_state

            # Copy the MJX data back to a MuJoCo data object for rendering.
            mj_data = mjx.get_data(self.env.mj_model, batch)
            renderer.update_scene(mj_data)
            frames.append(renderer.render())
            self.log_line(state, action, reward)
        media.show_video(frames, fps=framerate)

Running

In [16]:

# Instantiate the simulation environment
env = MujocoSim()

# Create ActorCritic instance
ac = ActorCritic(env, actor_params, critic_params)

# Train the model.
# batch_train is a Python driver loop with side effects (it prints, writes the
# CSV log and reassigns ac.actor_params / ac.critic_params) and it already jits
# its inner step/update functions. Wrapping the whole method in jax.jit would
# trace those side effects once and leave tracers stored on `ac`, so it is
# called directly.
print("begin")
ac.batch_train()

print("save params")
ac.save_parameters()
print("visualize one run")
ac.use_parameters()


begin
['state_0', 'state_1', 'state_2', 'state_3', 'state_4', 'state_5', 'state_6', 'state_7', 'state_8', 'state_9', 'state_10', 'state_11', 'state_12', 'state_13', 'state_14', 'state_15', 'state_16', 'state_17', 'state_18', 'state_19', 'state_20', 'state_21', 'state_22', 'state_23', 'state_24', 'state_25', 'state_26', 'action_0', 'action_1', 'action_2', 'action_3', 'action_4', 'action_5', 'action_6', 'reward']
ep 0
ep 1
ep 2
ep 3
ep 4
ep 5
ep 6
ep 7
ep 8
ep 9
ep 10
ep 11
ep 12
ep 13
ep 14
ep 15
ep 16
ep 17
ep 18
ep 19


In [None]:
# ac.use_parameters()
# print("save params")
# ac.save_parameters()
# print("visualize one run")
# ac.use_parameters()

save params


NameError: name 'ac' is not defined