# Import Packages

In [7]:
%env MUJOCO_GL=egl
from collections import deque

import matplotlib
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import nengo
import numpy as np
import PIL.Image

%matplotlib widget

# Access to enums and MuJoCo library functions.
from dm_control.mujoco.wrapper.mjbindings import enums
from dm_control.mujoco.wrapper.mjbindings import mjlib

from dm_control import composer, mjcf, mujoco
from dm_control.utils import inverse_kinematics as ik
from IPython.display import HTML
from multiprocessing import Process

env: MUJOCO_GL=egl


# Define Utilities and Parameters

In [8]:
# Rendering parameters
dpi = 100  # figure resolution used to size matplotlib figures from pixel dims
framerate = 24  # video capture / playback rate (Hz)
width = 720  # rendered image width in pixels
height = 480  # rendered image height in pixels
sensor_shape = (10, 10)  # grid layout the touch sensordata is reshaped into

# IK solver parameters
# presumably for dm_control inverse_kinematics (not referenced in the visible cells) — confirm
_MAX_STEPS = 50  # iteration cap
_TOL = 1e-12  # convergence tolerance


def display_video(frames, framerate=30):
    """Encode a sequence of images as an inline HTML5 video.

    The figure is created under the headless 'Agg' backend so it is not shown
    as a separate static plot; the caller's backend is restored afterwards.

    Args:
      frames: non-empty sequence of images, each unpacked as
        (height, width, channels); the first frame sets the figure size.
      framerate: playback rate in frames per second.

    Returns:
      An IPython ``HTML`` object embedding the encoded video.

    Raises:
      ValueError: if ``frames`` is empty.
    """
    # len() rather than truthiness so a numpy array of frames also works.
    if len(frames) == 0:
        raise ValueError('display_video requires at least one frame')
    height, width, _ = frames[0].shape
    orig_backend = matplotlib.get_backend()
    matplotlib.use('Agg')  # Switch to headless 'Agg' to inhibit figure rendering.
    try:
        fig, ax = plt.subplots(1, 1, figsize=(width / dpi, height / dpi), dpi=dpi)
    finally:
        # Restore the caller's backend even if figure creation fails, so the
        # notebook is not left stuck on the non-interactive Agg backend.
        matplotlib.use(orig_backend)
    ax.set_axis_off()
    ax.set_aspect('equal')
    ax.set_position([0, 0, 1, 1])
    im = ax.imshow(frames[0])

    def update(frame):
        im.set_data(frame)
        return [im]

    interval = 1000 / framerate  # milliseconds between frames
    anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,
                                   interval=interval, blit=True, repeat=False)
    return HTML(anim.to_html5_video())


def get_jacobian(physics, site_name):
    """Return the translational and rotational Jacobians of a site.

    The Jacobians are evaluated at the current state of ``physics`` via
    MuJoCo's ``mj_jacSite``. MuJoCo reads position-dependent quantities that
    are already stored in ``physics.data``; this function does not recompute
    them. NOTE(review): assumes they are current (e.g. right after
    ``physics.step()``). Call ``physics.forward()`` first if qpos was edited
    by hand.

    Args:
      physics: dm_control ``mujoco.Physics`` whose current state is used.
      site_name: name of the site in the model.

    Returns:
      ``(jac_pos, jac_rot)``: two ``(3, nv)`` float arrays mapping generalized
      velocities to the site's linear and angular velocity.
    """
    # Evaluate on the caller's model/data. A Physics rebuilt from the model
    # alone would hold default-initialised data with no kinematics computed,
    # giving a Jacobian unrelated to the arm's actual pose.
    nv = physics.model.nv
    jac_pos = np.zeros((3, nv))
    jac_rot = np.zeros((3, nv))
    mjlib.mj_jacSite(
        physics.model.ptr,
        physics.data.ptr,
        jac_pos,
        jac_rot,
        physics.model.name2id(site_name, 'site'))

    return jac_pos, jac_rot

# Run Simulations

In [18]:
# Load scene
scene_xml = 'models/scene.xml'
physics = mujoco.Physics.from_xml_path(scene_xml)

# Visualize initial pose: render every camera in a shuffled order and tile the
# images side by side in a single row.
random_state = np.random.RandomState(1024)
pixels = []
ncam = physics.model.ncam
cameras = random_state.choice(ncam, ncam, replace=False)
for camera_id in cameras:
    pixels.append(physics.render(camera_id=camera_id, width=width, height=height))
# One hstack yields the same row as stacking two sub-rows, and (unlike
# hstack(pixels[:ncam-1])) does not fail when the model has a single camera.
image = np.hstack(pixels) if pixels else None

# Define simulation variables
site_name = 'attachment_site'
control_site = physics.data.site(name=site_name)
target_site = physics.data.site(name='reach_site2')
joint_names = ['joint{}'.format(i+1) for i in range(7)]
# First 7 entries track the arm's qpos[:7]; the last 3 drive the rotator
# (see the loop below). NOTE(review): assumes the model has exactly 10 actuators.
ctrl = np.zeros(10)
duration = 1.0 # (seconds)
omega = np.pi/2 # Rotator angular velocity
video = []
sensor_data = []

# Simulate, saving video frames
physics.reset(0)
ctrl[:7] = physics.data.qpos[:7]
# Constant Cartesian displacement applied to the attachment site (down in z).
descent = np.array([0.0, 0.0, -0.05])
while physics.data.time < duration:
    # Lower the site until it reaches z = 0.3 using the Jacobian-transpose
    # rule dq = J^T dx on the position targets.
    if control_site.xpos[2] >= 0.3:
        jac_pos, _ = get_jacobian(physics, site_name)
        dq = np.dot(descent, jac_pos)
        ctrl[:7] += dq[:7]
    ctrl[-3:] = omega*physics.data.time
    physics.set_control(ctrl)
    physics.step()

    # Save video frames and sensor data at `framerate`.
    if len(video) < physics.data.time * framerate:
        pixels = physics.render(camera_id='prospective', width=width, height=height)
        video.append(pixels.copy())
        # sensordata is a view into MuJoCo's live buffer; copy so each stored
        # frame keeps its own values instead of all aliasing the latest step.
        sensor_data.append(physics.data.sensordata.reshape(sensor_shape).copy())

# PIL.Image.fromarray(image)
display_video(video, framerate)

# Online Rendering with Sensor Visualization

In [None]:
# Initialize animation
pixels = physics.render(camera_id='prospective', width=width, height=height)
video = [pixels]
# sensordata is a live view into MuJoCo's buffer; copy so the stored frame
# does not change as the simulation advances.
data = physics.data.sensordata.reshape(sensor_shape).copy()
sensor_data = [data]

fig, axs = plt.subplots(1, 2, figsize=(2*width / dpi, height / dpi), dpi=dpi)
axs[0].set_axis_off()
axs[0].set_aspect('equal')
im0 = axs[0].imshow(video[0])
axs[1].set_xlabel("neuron index")
axs[1].set_ylabel("neuron index")
axs[1].set_title("Touch Sensor")
axs[1].set_xticks(list(range(sensor_shape[1])))
axs[1].set_yticks(list(range(sensor_shape[0])))
im1 = axs[1].imshow(sensor_data[0])


def update(frame):
    """Draw one (camera image, taxel grid) pair yielded by data_gen."""
    image, taxels = frame
    im0.set_data(image)
    im1.set_data(taxels)
    # imshow fixed the colour limits from the first frame; rescale on every
    # frame so contact values that appear later remain visible.
    im1.autoscale()
    return [im0, im1]


def data_gen():
    """Step the simulation, yielding (pixels, taxels) once per video frame.

    Runs until `duration` seconds have elapsed since `t0`. Rendering happens
    only when a new frame is due at `framerate`, so intermediate physics steps
    are not rendered and the animation plays at simulation speed.
    """
    while physics.data.time - t0 < duration:
        # Inject controls and step the physics.
        ctrl[-3:] = omega*physics.data.time
        physics.set_control(ctrl)
        physics.step()

        # Frame timing is measured from t0: the physics clock continues from
        # the previous cell, so comparing against absolute time would capture
        # frames on consecutive physics steps.
        if len(video) < (physics.data.time - t0) * framerate:
            pixels = physics.render(camera_id='prospective', width=width, height=height)
            data = physics.data.sensordata.reshape(sensor_shape).copy()
            video.append(pixels.copy())
            sensor_data.append(data)
            yield pixels, data

t0 = physics.data.time
# The frame source is an open-ended generator, so disable frame caching to avoid
# retaining frames inside the animation object.
anim = animation.FuncAnimation(fig=fig, func=update, frames=data_gen,
                               interval=1000/framerate, blit=True,
                               cache_frame_data=False)
plt.tight_layout()
plt.show()

# Generate Strings Defining Taxel Sites and Touch Sensors to Inject into the XML File

In [None]:
# Print MJCF snippets for an n x n grid of taxels with spacing dx, centred on 0:
# first the <site> elements, then one <touch> sensor per site.
n = 10
dx = 3e-3
offset = dx*(n-1)/2.0
site_pad = ' ' * 24
touch_pad = ' ' * 4
for k in range(n*n):
    row, col = divmod(k, n)
    x_pos = dx*row - offset
    z_pos = dx*col - offset
    print(f'{site_pad}<site name="taxel_site{k + 1}" type="capsule" size=".001 .001" pos="{x_pos:.5f} -.002 {z_pos:.5f}" euler="1.570796 0 0"/>')
for k in range(n*n):
    print(f'{touch_pad}<touch name="taxel{k + 1}" site="taxel_site{k + 1}"/>')