In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

from absl import app
from absl import flags
import collections

from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
from dm_control.suite import common
from dm_control.suite.utils import randomizers
from dm_control.utils import containers

import scipy.io as sio
import matplotlib
matplotlib.use("Agg")
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np

In [9]:
_DEFAULT_TIME_LIMIT = 20
_CONTROL_SUBSTEPS = 100 #for use with control.Environment(... n_sub_steps=#)
CONVERSION_LENGTH = 1000 #ASSUMES .MAT TO BE IN MILLIMETERS

fileName='.\\demos\\ratMocapImputed'
varName='markers_preproc'
max_frame=33000
start_frame=24000
frame_step=100
model_filename='rodent_mocap_tendon.xml'

In [3]:
def get_model_and_assets():
  """Returns a tuple containing the model XML string and a dict of assets."""
  return common.read_model(model_filename), common.ASSETS

In [10]:
class jeffRat(base.Task):
    def __init__(self, random=None):
        super(jeffRat, self).__init__(random=random)
    def initialize_episode(self, physics):
        penetrating = True
        while penetrating:
            randomizers.randomize_limited_and_rotational_joints(
                physics, self.random)
            physics.after_reset()
            penetrating = physics.data.ncon > 0
    def get_observation(self, physics):
        obs = collections.OrderedDict()
        obs['joint_angles'] = physics.joint_angles()
        return obs
    def mocap(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
        physics = mujoco.Physics.from_xml_string(*get_model_and_assets())
        task = jeffRat(random=random)
        environment_kwargs = environment_kwargs or {}
        return control.Environment(
            physics, task, time_limit=time_limit, n_sub_steps=_CONTROL_SUBSTEPS,
            **environment_kwargs)
    def get_reward(self, physics):
        return 0

In [11]:
def parse(fileName, varName):
  parsed = collections.namedtuple('parsed', ['marks', 'bods', 'medianPose', 'mocap_pos'])
  marks = dict()
  medianPose = dict()
  zOffset = dict()
  values = sio.loadmat(fileName, variable_names = varName)
  bods = list(values[varName].dtype.fields)
  for bod in bods:
    marks[bod] = values[varName][bod][0][0]/CONVERSION_LENGTH #ASSUMES .MAT TO BE IN MILLIMETERS
    zOffset[bod] = np.amin(marks[bod], axis = 0) * [0, 0, 1]
  if min(zOffset[min(zOffset)]) < 0:
    for bod in bods:
      marks[bod] += abs(zOffset[min(zOffset)])
  for bod in bods:
    if bod == bods[0]:
      mocap_pos = marks[bod]
    else:
      mocap_pos = np.concatenate((mocap_pos, marks[bod]), axis = 1)
    medianPose[bod] = np.median(marks[bod], axis = {0})
  mocap_pos.shape = (mocap_pos.shape[0], int(mocap_pos.shape[1]/3), 3)
  mocap_pos = forFillNan(mocap_pos)
  return parsed(marks, bods, medianPose, mocap_pos)

In [12]:
def forFillNan(arr):
    for m in range(arr.shape[1]):
        if any(np.isnan(arr[:, m, 1])):
            ind = np.where(np.isnan(arr[:, m, 1])) #only works if all coordinates are NaNs    
            for i in ind[0]:
                if i == 0:
                    arr[i, m, :] = [0, 0, 0]
                else:
                    arr[i, m, :] = arr[i - 1, m, :].copy()
    return arr

In [13]:
def showFrame(data, frame, i):
    p_i = data.mocap_pos[i, :].copy()
    with env.physics.reset_context():
        env.physics.data.mocap_pos[:] = p_i
    while env.physics.time() < 2. or np.nanmean(abs(env.physics.data.qvel)) > 1e-06:
        env.physics.step()
    frame = np.hstack([env.physics.render(height, width, camera_id="front_side")])
    img = plt.imshow(frame)
    plt.waitforbuttonpress()
    plt.close()

In [14]:
def getFrame(data, i):
    p_i = data.mocap_pos[i, :].copy()
    with env.physics.reset_context():
        env.physics.data.mocap_pos[:] = p_i
    while env.physics.time() < 2. or np.nanmean(abs(env.physics.data.qvel)) > 1e-06:
        env.physics.step()
    return np.hstack([env.physics.render(height, width, camera_id="front_side")])    

In [15]:
env = jeffRat.mocap()
data = parse(fileName, varName)
metadata = dict(title= model_filename + ': ' + fileName, artist='Jeff Rhoades/Jesse Marshall/DeepMind',
                  comment=varName)
writer = animation.FFMpegWriter(fps=3, metadata=metadata)
width = 500
height = 500

  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
  r = func(a, **kwargs)


In [16]:
max_frame = min(max_frame, data.mocap_pos.shape[0])
max_num_frames = (max_frame - start_frame)//frame_step
video = np.zeros((max_num_frames, height, width, 3), dtype=np.uint8)
qVid = np.zeros((max_num_frames, env.physics.data.qpos.size), dtype=np.uint8)

In [17]:
with env.physics.reset_context():
      env.physics.data.qpos[:] = env.physics.model.key_qpos

In [19]:
for i in range(start_frame, max_frame, frame_step):
    p_i = data.mocap_pos[i, :].copy()
    with env.physics.reset_context():
      env.physics.data.mocap_pos[:] = p_i
    while env.physics.time() < 2. or np.nanmean(abs(env.physics.data.qvel)) > 1e-06:
      env.physics.step()

    video[(i - start_frame)//frame_step] = np.hstack([env.physics.render(height, width, camera_id="front_side")])
    qVid[(i - start_frame)//frame_step] = env.physics.data.qpos[:]

In [None]:
fig = plt.figure()
img = plt.imshow(video[1])

In [None]:
plt.rcParams['animation.ffmpeg_path'] = r'C:\Users\RatControl\ffmpeg\bin\ffmpeg.exe'

In [None]:
with writer.saving(fig, fileName + "_vid.mp4", dpi=None):
    for i in range(video.shape[0]):
      if i == 0:
        img = plt.imshow(video[i])
      else:
        img.set_data(video[i])
      plt.draw()
      writer.grab_frame()
writer.finish()

In [27]:
env.physics.data.mocap_pos

array([[ 9.73277926e-02, -1.11613230e-02,  1.08734119e-01],
       [ 7.94162067e-02, -1.62186141e-02,  1.15701790e-01],
       [ 7.43458132e-02,  1.07417317e-02,  1.27369634e-01],
       [ 3.37840843e-02,  1.00000000e-06,  9.40371937e-02],
       [ 1.00000000e-06,  1.00000000e-06,  1.05950729e-01],
       [-5.11375357e-02, -4.01853545e-03,  8.53570349e-02],
       [ 1.37338658e-02,  2.50540572e-02,  9.57915201e-02],
       [-2.31188586e-02,  2.82226992e-02,  9.61352321e-02],
       [-6.81922041e-02,  1.39842433e-02,  7.75000700e-02],
       [-3.79932853e-02, -2.24907451e-02,  6.62535116e-02],
       [ 3.46615604e-02,  1.76026441e-02,  2.60562705e-02],
       [ 4.42273303e-02,  1.71157019e-02,  1.90648961e-02],
       [ 3.34544863e-02,  2.18996579e-02,  9.02344340e-02],
       [ 3.46163425e-02, -1.30273278e-02,  7.64685416e-02],
       [ 2.81739495e-02, -1.05859360e-02,  3.28679493e-02],
       [ 3.61650649e-02, -7.91491522e-03,  2.27443101e-02],
       [-2.24588784e-02, -1.71874240e-02