# Pretrain Policy with  `MPC dataset`

In [1]:
# %matplotlib
import numpy as np
import pickle
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from mujoco_parser import MuJoCoParserClass
import torch.optim as optim

from tqdm import tqdm
import matplotlib.pyplot as plt
from util import r2quat

from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from policy import GaussianPolicy
from pid import PID_ControllerClass
from motion_vae import MotionVariationalAutoEncoderClass

import mediapy as media

### Load  `MPC dataset`

In [2]:
# with open(file='../data/MPC_dataset_240324_bigggger_noise-2.pkl', mode='rb') as f:
#     dataset = pickle.load(f)
with open(file='../data/SMPL_MPC_dataset_240430_single2.pkl', mode='rb') as f:
    dataset = pickle.load(f)

horizon = 300
action = torch.Tensor(dataset['action'])
root_pos = torch.Tensor(dataset['root_pos'])
root_rot = torch.Tensor(dataset['root_rot'])
root_vel = torch.Tensor(dataset['root_vel'])
root_ang_vel = torch.Tensor(dataset['root_ang_vel'])
dof_pos = torch.Tensor(dataset['dof_pos'])
dof_vel = torch.Tensor(dataset['dof_vel'])
local_key_pos = torch.Tensor(dataset['local_key_pos'])

del dataset

# qpos_batch = qpos_batch[:, :]
# qpos_batch = torch.cat((qpos_batch[:-2, :],qpos_batch[1:-1, :],qpos_batch[2:, :]),dim=1)
# qvel_batch = torch.cat((qvel_batch[:-2, :],qvel_batch[1:-1, :],qvel_batch[2:, :]),dim=1)

obs = torch.cat((root_pos[:,2:3], root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, local_key_pos.reshape(-1, 12)), dim=1)
# obs = torch.cat((root_pos, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, local_key_pos.reshape(-1, 12)), dim=1)
# obs = torch.cat((root_pos, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel), dim=1)

# action_batch = action_batch[2:]
obs_dim = obs.shape[1]
action_dim = action.shape[1]

print("obs : ", obs.shape)
print("action : ", action.shape)


obs :  torch.Size([300, 97])
action :  torch.Size([300, 37])


In [3]:
xml_path = '../asset/smpl_rig/scene_smpl_rig.xml'
env = MuJoCoParserClass(name='smplrig',rel_xml_path=xml_path,VERBOSE=True)
# Modify the color of body exluding 'world'
for body_name in env.body_names:
    if body_name in ['world']: continue
    body_idx = env.body_names.index(body_name)
    geom_idxs = [idx for idx,val in enumerate(env.model.geom_bodyid) if val==body_idx]
    for geom_idx in geom_idxs:
        env.model.geom(geom_idx).rgba = [0.3,0.3,0.5,0.5]
print ("Done.")

dt:[0.0083] HZ:[120]
n_dof (=nv):[43]
n_geom:[67]
geom_names:['floor', 'base', 'base2lpelvis', 'base2rpelvis', 'base2spine1', 'spine2spine', 'spine2spine2', 'spine2spine3', 'spine2lcollar', 'spine2rcollar', 'neck2head', 'head', 'nose', 'lcollar2lshoulder', 'lshoulder2lelbow', 'lelbow2lwrist', 'lwrist2lindex1', 'lwrist2lmiddle1', 'lwrist2lpinky1', 'lwrist2lring1', 'lwrist2lthumb1', 'lindex1-lindex2', 'lindex2-lindex3', 'lindex3-lindextip', 'lmiddle1-lmiddle2', 'lmiddle2-lmiddle3', 'lmiddl3-lmiddletip', 'lring1-lring2', 'lring2-lring3', 'lring3-lringtip', 'lpinky1-lpinky2', 'lpinky2-lpinky3', 'lpinky3-lpinkytip', 'lthumb1-lthumb2', 'lthumb2-lthumb3', 'lthumb3-lthumbtip', 'rcollar2rshoulder', 'rshoulder2relbow', 'relbow2rwrist', 'rwrist2rindex1', 'rwrist2rmiddle1', 'rwrist2rpinky1', 'rwrist2rring1', 'rwrist2rthumb1', 'rindex1-rindex2', 'rindex2-rindex3', 'rindex3-rindextip', 'rmiddle1-rmiddle2', 'rmiddle2-rmiddle3', 'rmiddle3-rmiddletip', 'rring1-rring2', 'rring2-rring3', 'rring3-rringtip

### Load reference walking motion

In [4]:
motion_name = 'smplrig_walk_optimized_recon_local'
pkl_path = '../data/%s.pkl'%(motion_name)
with open(pkl_path, 'rb') as f:
    data_reference = pickle.load(f)
print ("[%s] loaded."%(pkl_path))

[../data/smplrig_walk_optimized_recon_local.pkl] loaded.


In [5]:
KEY_BODY_NAMES = ["right_ankle", "left_ankle", "right_wrist", "left_wrist"]
key_body_ids= np.array(env.get_idxs_body(KEY_BODY_NAMES))

In [6]:
### PID
ctrl_ranges     = env.model.actuator_ctrlrange
n_ctrl          = env.model.nu
hidden_dim = 512

world_model = GaussianPolicy(
    input_dim=obs_dim+action_dim,
    output_dim=obs_dim-12, # exclude key body pos
    hidden_dim=hidden_dim,
    is_deterministic=True,
)

RunningMeanStd:  134
RunningMeanStd:  85


In [7]:
class MPCDataset(Dataset):
    def __init__(self, obs, act, horizon=200):
        # self.obs = obs
        # self.obs = (self.obs - self.obs.mean(dim=0)) / self.obs.std(dim=0)
        self.act = act[:]# * env.model.actuator_gear[:, 0]
        self.obs = obs[:]
        self.horizon = horizon

        assert self.obs.shape[0] == self.act.shape[0]

    def __len__(self):
        return self.obs.shape[0]

    def __getitem__(self,idx):
        return self.obs[idx], self.act[idx]

In [8]:
train_dataset = MPCDataset(obs, action, horizon)
train_dataloader = DataLoader(train_dataset, batch_size=4000, shuffle=True)
# test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

act_scale = train_dataset.act.max(dim=0).values - train_dataset.act.min(dim=0).values
obs_scale = train_dataset.obs.max(dim=0).values - train_dataset.obs.min(dim=0).values
# obs_scale[0:5] = 0

In [9]:
def linear_decay_lr(start_lr, end_lr, epoch, total_epochs):
    """
    선형적으로 감소하는 learning rate를 계산하는 함수
    
    :param start_lr: 시작 learning rate
    :param end_lr: 끝 learning rate
    :param epoch: 현재 에포크
    :param total_epochs: 총 에포크 수
    :return: 선형 감소된 learning rate
    """
    return end_lr + (start_lr - end_lr) * (1 - epoch / total_epochs)


In [10]:
def obs_to_pos_vel(obs):
    height = obs[0:1]
    root_rot = obs[1:5]
    root_vel = obs[5:8]
    root_ang_vel = obs[8:11]
    dof_pos = obs[11:48]
    dof_vel = obs[48:85]
    local_key_pos = obs[85:97]

    return height, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, local_key_pos
def get_next_state_sample(obses, acts, noise_rate = 0.0):
    next_states_np = np.empty((0,85))
    acts_np = np.empty((0,37))
    for obs, act in zip(obses, acts):
        obs += obs_scale * torch.randn_like(obs)

        height, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, local_key_pos = obs_to_pos_vel(obs)

        env.data.qvel = np.concatenate((root_vel, root_ang_vel, dof_vel))
        state = np.concatenate((np.zeros((2,)),height,root_rot,dof_pos))
        # state += np.array(obs_scale * torch.randn_like(state) * noise_rate)
        
        env.forward(state)
        
        act += act_scale * torch.randn_like(act) * noise_rate
        act = np.array(act)
        env.step(act)

        next_qpos = env.data.qpos.copy()
        next_qvel = env.data.qpos.copy()
        key_body_pos = env.data.xpos[key_body_ids].copy()
        next_state = np.concatenate((next_qpos[2:7],next_qvel[0:6],next_qpos[7:44],next_qvel[6:43])) # ,key_body_pos.reshape(12)
        next_states_np=np.concatenate((next_states_np,[next_state]),axis=0)
        acts_np=np.concatenate((acts_np,[act]),axis=0)

    return obses, acts_np, next_states_np
    

In [11]:
def loss_function(GT_y,pred_y):
    l = torch.nn.L1Loss()
    loss = l(pred_y, torch.tensor(GT_y))
    return loss

In [12]:
num_epoch = 1000
optimizer = optim.Adam(world_model.parameters(), lr=1e-4, eps=1e-8, betas=(0.9, 0.95), weight_decay=0.01)

In [13]:
# policy.action_mean_std.train()
for epoch in range(num_epoch):
    with tqdm(train_dataloader, unit="batch") as tepoch:
        for x, y in tepoch:
            
            tepoch.set_description(f"Epoch {epoch+1}")
            optimizer.zero_grad()
            
            obses, acts_np, next_states_np = get_next_state_sample(x, y)
            input = torch.cat((obses, torch.tensor(acts_np,dtype=torch.float32)),axis=1)
            loss = loss_function(next_states_np, world_model(input)[0])
            # loss,loss_info=MotionVAEPolicy.loss_total(x=x, c=t, prev_x=prev_x, y=y, beta=beta)
            loss.backward()
            optimizer.step()
            
            tepoch.set_postfix(recon_loss=float(loss))

Epoch 1:   0%|          | 0/1 [00:00<?, ?batch/s]

Epoch 1: 100%|██████████| 1/1 [00:00<00:00, 12.54batch/s, recon_loss=0.831]
Epoch 2: 100%|██████████| 1/1 [00:00<00:00, 15.35batch/s, recon_loss=0.764]
Epoch 3: 100%|██████████| 1/1 [00:00<00:00, 15.88batch/s, recon_loss=0.694]
Epoch 4: 100%|██████████| 1/1 [00:00<00:00, 15.87batch/s, recon_loss=0.63]
Epoch 5: 100%|██████████| 1/1 [00:00<00:00, 16.40batch/s, recon_loss=0.575]
Epoch 6: 100%|██████████| 1/1 [00:00<00:00, 16.40batch/s, recon_loss=0.528]
Epoch 7: 100%|██████████| 1/1 [00:00<00:00, 16.73batch/s, recon_loss=0.49]
Epoch 8: 100%|██████████| 1/1 [00:00<00:00, 16.44batch/s, recon_loss=0.46]
Epoch 9: 100%|██████████| 1/1 [00:00<00:00, 16.01batch/s, recon_loss=0.432]
Epoch 10: 100%|██████████| 1/1 [00:00<00:00, 15.97batch/s, recon_loss=0.408]
Epoch 11: 100%|██████████| 1/1 [00:00<00:00, 16.50batch/s, recon_loss=0.386]
Epoch 12: 100%|██████████| 1/1 [00:00<00:00, 16.27batch/s, recon_loss=0.366]
Epoch 13: 100%|██████████| 1/1 [00:00<00:00, 13.54batch/s, recon_loss=0.348]
Epoch 14: 1

### Test pretrain model

In [14]:
# Set which joints to control
ctrl_joint_names = env.ctrl_joint_names # <==
joint_idxs_fwd = env.get_idxs_fwd(joint_names=ctrl_joint_names)
joint_idxs_jac = env.get_idxs_jac(joint_names=ctrl_joint_names)
q_ctrl_init = env.get_qpos_joints(ctrl_joint_names)
n_ctrl_joint = len(ctrl_joint_names)

In [15]:
LOAD = False
file_path = "240501 worldmodel.pth"
if LOAD:
    world_model.load_state_dict((torch.load(file_path)))

In [16]:
world_model.eval()
world_model.running_mean_std.eval()

RunningMeanStd()

In [17]:
# Initialize MuJoCo viewer
env.init_viewer(viewer_title='SMPL',viewer_width=1200,viewer_height=800,
                viewer_hide_menus=False)
env.update_viewer(azimuth=152,distance=3.0,elevation=-20,lookat=[0.02,-0.03,1.2])
env.reset()
image_list = []
for t in np.arange(0,1):
    # assign the state in dataset (TODO : implement with `obs` variable)
    time_offset = 10
    rp = root_pos[t+time_offset]
    rr = root_rot[t+time_offset]
    rv = root_vel[t+time_offset]
    rw = root_ang_vel[t+time_offset]
    qpos = dof_pos[t+time_offset]
    qvel = dof_vel[t+time_offset]
    env.data.qvel = torch.cat((rv, rw, qvel))
    # env.data.qvel = np.zeros_like(env.data.qvel)
    env.forward(torch.cat((rp, rr, qpos)))
    env.data.qacc = np.zeros_like(env.data.qacc)

    prev_state = None
    for h in range(300-time_offset):
        # env.forward(torch.cat((rp[0:2],root_pos[t+h,2:3]+0.2, root_rot[t+h])),joint_idxs=np.arange(0,7))
        ctrl = action[t+h+time_offset]
        state = torch.from_numpy(np.concatenate((env.get_p_body('base')[2:3], r2quat(env.get_R_body('base')), env.get_qvel_joint('base')[0:3], env.get_qvel_joint('base')[3:6], env.data.qpos[env.rev_joint_idxs+6].copy(), env.data.qvel[env.rev_joint_idxs+5].copy(), (env.data.xpos[key_body_ids]-env.get_p_body('base')).reshape(12)), axis=-1)).type(torch.float32)
        phase = min(t+h+time_offset,314)
        phase = torch.tensor([phase/314.], dtype=torch.float32)
        # next_state = torch.from_numpy(np.concatenate((root_pos[next_idx,2:],root_rot[next_idx],root_vel[next_idx],root_ang_vel[next_idx],dof_pos[next_idx],dof_vel[next_idx],local_key_pos[next_idx]), axis=-1))
        # ext_state = torch.cat((state, next_state))
        if prev_state is None:
            prev_state = state
        # ext_state = MotionVAEPolicy.running_mean_std(ext_state)
        
        input = torch.cat((state, ctrl))
        pred_state = world_model(input)[0]
        pred_next_state = torch.cat((torch.zeros((2,)),pred_state[0:5],pred_state[11:48]))
        pred_next_state = pred_next_state.detach().numpy()
        env.forward(pred_next_state)
        if env.loop_every(tick_every=1):
            # Plot world frame
            env.plot_T(p=np.zeros(3),R=np.eye(3,3),
                    PLOT_AXIS=True,axis_len=0.5,axis_width=0.005)
            env.plot_T(p=np.array([0,0,0.5]),R=np.eye(3,3),
                    PLOT_AXIS=False,label="tick:[%d]"%(env.tick))

            # Plot contact information
            env.render()

env.close_viewer()

In [18]:
with media.VideoWriter(
    "behavior cloning eval.mp4", shape=(480, 640), fps=5) as w:
    for image in image_list:
        w.add_image(image)

In [19]:
SAVE = True
file_path = "240501 worldmodel.pth"
if SAVE:
    torch.save(world_model.state_dict(), file_path)

In [20]:
policy.state_dict()

NameError: name 'policy' is not defined