In [None]:
%matplotlib inline
from matplotlib import cm
import matplotlib.pyplot as plt
import numpy as np

# Mock-up of the per-grip sampling plot: random positions coloured by reward,
# with one colorbar shared by all panels.
reward_seqs = [0.0, 0.18974148, 0.1992907, 0.350656, 0.35882896, 0.39109656,
               0.39486712, 0.39511558, 0.46375644, 0.47071177, 0.5655106, 0.70239437,
               0.70598173, 0.72608304, 0.7819717, 0.79253316, 0.8283197, 0.9700098,
               0.99150366, 1.0]
N_GRIPS = 3
rng = np.random.default_rng(0)   # fixed seed so Restart & Run All reproduces the figure
n_points = len(reward_seqs)      # one point per reward (was hard-coded to 20)

fig, axs = plt.subplots(1, N_GRIPS, figsize=(24, 4.8))
for grip, ax in enumerate(axs, start=1):
    # Continuous positions: randint(5) put 20 points on a 5x5 grid, so
    # overlapping markers hid each other's colours.
    scatter = ax.scatter(rng.uniform(0, 5, size=n_points),
                         rng.uniform(0, 5, size=n_points),
                         c=reward_seqs, cmap=cm.jet)
    ax.set_title(f"GRIP {grip}")
    ax.set_xlabel('x coordinate')
    ax.set_ylabel('z coordinate')

# Reuse the scatter's own mappable so the colorbar shares its exact norm and cmap
# (every panel uses the same `c`, so any panel's mappable is representative).
fig.colorbar(scatter, ax=axs)

# plt.savefig(path)
plt.show()

In [None]:
# Named after CMA-ES, but see the NOTE(review) in the docstring: no covariance adaptation yet.
def optimize_action_CMA_ES(
    self,
    init_pose_seqs,
    act_seqs,
    reward_seqs,    # [n_sample]
    best_k_ratio=0.05
):
    """Resample a batch of init-pose/action sequences around the top-rewarded ones.

    NOTE(review): despite the name this is not CMA-ES. The mean/covariance/
    evolution-path stubs (m, C, p_sigma, p_c) were never used and are removed;
    C = np.eye(N) also allocated an N x N matrix for nothing.

    The ``best_k`` highest-reward sequences are kept as elites. Walking from
    the k-th best to the best, an elite of rank ``r > 1`` contributes
    ``max(1, int(N / 2**r))`` entries (itself plus perturbed copies), and the
    best elite fills the remainder so exactly ``N`` entries are returned.
    A perturbed copy adds Gaussian noise to each step's mid-point third
    coordinate (index 2, std 0.03) and rotation angle (std pi/36), each
    clipped to [-0.1, 0.1]. It then rebuilds the poses with ``self.get_pose``
    and the actions with ``self.get_action_seq_from_pose``.

    Args:
        init_pose_seqs: batch of pose sequences; axis 0 indexes the N samples.
        act_seqs: action sequences aligned with ``init_pose_seqs`` on axis 0.
        reward_seqs: N rewards; must support fancy indexing (ndarray/tensor).
        best_k_ratio: fraction of N kept as elites (at least 3, capped at N).

    Returns:
        ``(init_pose_seqs_sample, act_seqs_sample)``, each built with
        ``np.stack`` and holding N entries along axis 0.
    """
    import os  # the notebook's import cell does not import os

    n_total = init_pose_seqs.shape[0]
    # Cap at n_total: for N < 3, max(3, ...) alone made idx[-best_k] raise IndexError.
    best_k = min(n_total, max(3, int(n_total * best_k_ratio)))
    idx = np.argsort(reward_seqs)
    print(f"Selected top reward seqs: {reward_seqs[idx[-best_k:]]}")

    self.visualize_sampled_init_pos(
        init_pose_seqs, reward_seqs, idx,
        os.path.join(self.rollout_path, f'plot_cem_s{self.sample_iter_cur}_o{self.opt_iter_cur}'))

    init_pose_seqs_sample = []
    act_seqs_sample = []
    for rank in range(best_k, 0, -1):
        elite = idx[-rank]
        init_pose_seq = init_pose_seqs[elite]
        init_pose_seqs_sample.append(init_pose_seq)
        act_seqs_sample.append(act_seqs[elite])

        if rank > 1:
            group_size = int(n_total / (2**rank))
        else:
            # The best elite tops the batch up to exactly n_total entries.
            group_size = n_total - len(init_pose_seqs_sample) + 1
        if group_size <= 1:
            continue  # only the elite itself

        # Loop-invariant per elite (assumes the helper is pure): decompose once, not per copy.
        mid_point_seq, angle_seq = self.get_center_and_rot_from_pose(init_pose_seq)
        for _ in range(group_size - 1):
            perturbed = []
            for k in range(init_pose_seq.shape[0]):
                p_noise = np.clip(np.array([0, 0, np.random.randn()*0.03]), a_max=0.1, a_min=-0.1)
                rot_noise = np.clip(np.random.randn() * np.pi / 36, a_max=0.1, a_min=-0.1)
                perturbed.append(self.get_pose(mid_point_seq[k, :3] + p_noise, angle_seq[k] + rot_noise))
            perturbed = np.stack(perturbed)
            init_pose_seqs_sample.append(perturbed)
            act_seqs_sample.append(self.get_action_seq_from_pose(perturbed))

    return np.stack(init_pose_seqs_sample), np.stack(act_seqs_sample)

In [None]:
def optimize_action_MPPI(   # Model-Predictive Path Integral (MPPI)
    self,
    init_pose_seqs,
    act_seqs,       # [n_sample, -1, action_dim]
    reward_seqs     # [n_sample]
):
    """Collapse sampled grasp rotations into one reward-weighted rotation sequence.

    For each sample and step, the rotation is recovered from the x coordinate
    of point ``self.gripper_mid_pt`` as
    ``arccos((x - self.mid_point[0]) / self.sample_radius)``. Rotations are
    then averaged over samples with weights ``exp(self.reward_weight * r_hat)``,
    where ``r_hat`` is the mean-centred reward vector scaled to unit L2 norm.

    Args:
        init_pose_seqs: poses indexed ``[sample, step, point, coord]``.
        act_seqs: unused here; kept so all optimizers share one signature.
        reward_seqs: array-like with one reward per sample.

    Returns:
        ``(init_pose_seq, act_seq)``: ``np.stack`` of ``self.get_pose`` and
        ``self.get_action_seq`` results, one entry per step.
    """
    eps = 1e-8
    print(f"reward_seqs: {reward_seqs}")
    rewards = np.asarray(reward_seqs, dtype=np.float64)

    # The former extra division by np.var cancelled out under the L2
    # normalisation but turned identical rewards into 0/0 = NaN; all-equal
    # rewards now yield uniform weights instead.
    centered = rewards - rewards.mean()
    norm = np.linalg.norm(centered)
    reward_seqs_norm = centered / norm if norm > 0 else np.zeros_like(centered)
    reward_seqs_exp = np.exp(self.reward_weight * reward_seqs_norm)  # [n_sample]
    print(f"reward_seqs_exp: {reward_seqs_exp}")

    # Broadcast against a scalar instead of np.full((self.n_sample, T), ...),
    # which failed whenever the batch size differed from self.n_sample.
    x_offset = np.asarray(init_pose_seqs[:, :, self.gripper_mid_pt, 0], dtype=np.float64) \
        - float(self.mid_point[0])
    # Clip: rounding can push |cos| just past 1, where arccos returns NaN.
    # NOTE(review): arccos only covers [0, pi]; the sign of the z offset is not used.
    rot_noise_seqs = np.arccos(np.clip(x_offset / self.sample_radius, -1.0, 1.0))  # [n_sample, T]

    rot_noise_seq = (reward_seqs_exp[:, None] * rot_noise_seqs).sum(axis=0) \
        / (reward_seqs_exp.sum() + eps)
    print(f"rot_noise_seq: {rot_noise_seq}")

    init_pose_seq = []
    act_seq = []
    for rot_noise in rot_noise_seq:
        init_pose_seq.append(self.get_pose(self.mid_point, rot_noise))
        act_seq.append(self.get_action_seq(rot_noise))

    return np.stack(init_pose_seq), np.stack(act_seq)

In [None]:
p state_cur
(Pdb) tensor([[[0.3904, 0.0959, 0.4351],
         [0.6174, 0.1879, 0.5990],
         [0.6135, 0.0670, 0.3950],
         ...,
         [0.6636, 0.1940, 0.3645],
         [0.6636, 0.2120, 0.3645],
         [0.6636, 0.2300, 0.3645]],

        [[0.4352, 0.1039, 0.4383],
         [0.6189, 0.0736, 0.6159],
         [0.6146, 0.1853, 0.4382],
         ...,
         [0.6561, 0.1940, 0.3712],
         [0.6561, 0.2120, 0.3712],
         [0.6561, 0.2300, 0.3712]],

        [[0.4693, 0.1672, 0.5154],
         [0.6146, 0.0767, 0.3904],
         [0.6201, 0.0772, 0.6049],
         ...,
         [0.6487, 0.1940, 0.3779],
         [0.6487, 0.2120, 0.3779],
         [0.6487, 0.2300, 0.3779]],

        [[0.4785, 0.0970, 0.4615],
         [0.6114, 0.1729, 0.6151],
         [0.4019, 0.1768, 0.6142],
         ...,
         [0.6413, 0.1940, 0.3846],
         [0.6413, 0.2120, 0.3846],
         [0.6413, 0.2300, 0.3846]]])
p init_pose_seqs
(Pdb) array([[[[0.66649942, 0.05      , 0.7162913 , 1.        , 0.        ,
          0.        , 0.        , 0.33350058, 0.05      , 0.3433144 ,
          1.        , 0.        , 0.        , 0.        ],
         [0.66649942, 0.068     , 0.7162913 , 1.        , 0.        ,
          0.        , 0.        , 0.33350058, 0.068     , 0.3433144 ,
          1.        , 0.        , 0.        , 0.        ],
         [0.66649942, 0.086     , 0.7162913 , 1.        , 0.        ,
          0.        , 0.        , 0.33350058, 0.086     , 0.3433144 ,
          1.        , 0.        , 0.        , 0.        ],
         [0.66649942, 0.104     , 0.7162913 , 1.        , 0.        ,
          0.        , 0.        , 0.33350058, 0.104     , 0.3433144 ,
          1.        , 0.        , 0.        , 0.        ],
         [0.66649942, 0.122     , 0.7162913 , 1.        , 0.        ,
          0.        , 0.        , 0.33350058, 0.122     , 0.3433144 ,
          1.        , 0.        , 0.        , 0.        ],
         [0.66649942, 0.14      , 0.7162913 , 1.        , 0.        ,
          0.        , 0.        , 0.33350058, 0.14      , 0.3433144 ,
          1.        , 0.        , 0.        , 0.        ],
         [0.66649942, 0.158     , 0.7162913 , 1.        , 0.        ,
          0.        , 0.        , 0.33350058, 0.158     , 0.3433144 ,
          1.        , 0.        , 0.        , 0.        ],
         [0.66649942, 0.176     , 0.7162913 , 1.        , 0.        ,
          0.        , 0.        , 0.33350058, 0.176     , 0.3433144 ,
          1.        , 0.        , 0.        , 0.        ],
         [0.66649942, 0.194     , 0.7162913 , 1.        , 0.        ,
          0.        , 0.        , 0.33350058, 0.194     , 0.3433144 ,
          1.        , 0.        , 0.        , 0.        ],
         [0.66649942, 0.212     , 0.7162913 , 1.        , 0.        ,
          0.        , 0.        , 0.33350058, 0.212     , 0.3433144 ,
          1.        , 0.        , 0.        , 0.        ],
         [0.66649942, 0.23      , 0.7162913 , 1.        , 0.        ,
          0.        , 0.        , 0.33350058, 0.23      , 0.3433144 ,
          1.        , 0.        , 0.        , 0.        ]]],


       [[[0.27943408, 0.05      , 0.60939304, 1.        , 0.        ,
          0.        , 0.        , 0.72056592, 0.05      , 0.37401524,
          1.        , 0.        , 0.        , 0.        ],
         [0.27943408, 0.068     , 0.60939304, 1.        , 0.        ,
          0.        , 0.        , 0.72056592, 0.068     , 0.37401524,
          1.        , 0.        , 0.        , 0.        ],
         [0.27943408, 0.086     , 0.60939304, 1.        , 0.        ,
          0.        , 0.        , 0.72056592, 0.086     , 0.37401524,
          1.        , 0.        , 0.        , 0.        ],
         [0.27943408, 0.104     , 0.60939304, 1.        , 0.        ,
          0.        , 0.        , 0.72056592, 0.104     , 0.37401524,
          1.        , 0.        , 0.        , 0.        ],
         [0.27943408, 0.122     , 0.60939304, 1.        , 0.        ,
          0.        , 0.        , 0.72056592, 0.122     , 0.37401524,
          1.        , 0.        , 0.        , 0.        ],
         [0.27943408, 0.14      , 0.60939304, 1.        , 0.        ,
          0.        , 0.        , 0.72056592, 0.14      , 0.37401524,
          1.        , 0.        , 0.        , 0.        ],
         [0.27943408, 0.158     , 0.60939304, 1.        , 0.        ,
          0.        , 0.        , 0.72056592, 0.158     , 0.37401524,
          1.        , 0.        , 0.        , 0.        ],
         [0.27943408, 0.176     , 0.60939304, 1.        , 0.        ,
          0.        , 0.        , 0.72056592, 0.176     , 0.37401524,
          1.        , 0.        , 0.        , 0.        ],
         [0.27943408, 0.194     , 0.60939304, 1.        , 0.        ,
          0.        , 0.        , 0.72056592, 0.194     , 0.37401524,
          1.        , 0.        , 0.        , 0.        ],
         [0.27943408, 0.212     , 0.60939304, 1.        , 0.        ,
          0.        , 0.        , 0.72056592, 0.212     , 0.37401524,
          1.        , 0.        , 0.        , 0.        ],
         [0.27943408, 0.23      , 0.60939304, 1.        , 0.        ,
          0.        , 0.        , 0.72056592, 0.23      , 0.37401524,
          1.        , 0.        , 0.        , 0.        ]]],


       [[[0.25052257, 0.05      , 0.61090853, 1.        , 0.        ,
          0.        , 0.        , 0.74947743, 0.05      , 0.57859701,
          1.        , 0.        , 0.        , 0.        ],
         [0.25052257, 0.068     , 0.61090853, 1.        , 0.        ,
          0.        , 0.        , 0.74947743, 0.068     , 0.57859701,
          1.        , 0.        , 0.        , 0.        ],
         [0.25052257, 0.086     , 0.61090853, 1.        , 0.        ,
          0.        , 0.        , 0.74947743, 0.086     , 0.57859701,
          1.        , 0.        , 0.        , 0.        ],
         [0.25052257, 0.104     , 0.61090853, 1.        , 0.        ,
          0.        , 0.        , 0.74947743, 0.104     , 0.57859701,
          1.        , 0.        , 0.        , 0.        ],
         [0.25052257, 0.122     , 0.61090853, 1.        , 0.        ,
          0.        , 0.        , 0.74947743, 0.122     , 0.57859701,
          1.        , 0.        , 0.        , 0.        ],
         [0.25052257, 0.14      , 0.61090853, 1.        , 0.        ,
          0.        , 0.        , 0.74947743, 0.14      , 0.57859701,
          1.        , 0.        , 0.        , 0.        ],
         [0.25052257, 0.158     , 0.61090853, 1.        , 0.        ,
          0.        , 0.        , 0.74947743, 0.158     , 0.57859701,
          1.        , 0.        , 0.        , 0.        ],
         [0.25052257, 0.176     , 0.61090853, 1.        , 0.        ,
          0.        , 0.        , 0.74947743, 0.176     , 0.57859701,
          1.        , 0.        , 0.        , 0.        ],
         [0.25052257, 0.194     , 0.61090853, 1.        , 0.        ,
          0.        , 0.        , 0.74947743, 0.194     , 0.57859701,
          1.        , 0.        , 0.        , 0.        ],
         [0.25052257, 0.212     , 0.61090853, 1.        , 0.        ,
          0.        , 0.        , 0.74947743, 0.212     , 0.57859701,
          1.        , 0.        , 0.        , 0.        ],
         [0.25052257, 0.23      , 0.61090853, 1.        , 0.        ,
          0.        , 0.        , 0.74947743, 0.23      , 0.57859701,
          1.        , 0.        , 0.        , 0.        ]]],


       [[[0.71615574, 0.05      , 0.67165141, 1.        , 0.        ,
          0.        , 0.        , 0.28384426, 0.05      , 0.42044075,
          1.        , 0.        , 0.        , 0.        ],
         [0.71615574, 0.068     , 0.67165141, 1.        , 0.        ,
          0.        , 0.        , 0.28384426, 0.068     , 0.42044075,
          1.        , 0.        , 0.        , 0.        ],
         [0.71615574, 0.086     , 0.67165141, 1.        , 0.        ,
          0.        , 0.        , 0.28384426, 0.086     , 0.42044075,
          1.        , 0.        , 0.        , 0.        ],
         [0.71615574, 0.104     , 0.67165141, 1.        , 0.        ,
          0.        , 0.        , 0.28384426, 0.104     , 0.42044075,
          1.        , 0.        , 0.        , 0.        ],
         [0.71615574, 0.122     , 0.67165141, 1.        , 0.        ,
          0.        , 0.        , 0.28384426, 0.122     , 0.42044075,
          1.        , 0.        , 0.        , 0.        ],
         [0.71615574, 0.14      , 0.67165141, 1.        , 0.        ,
          0.        , 0.        , 0.28384426, 0.14      , 0.42044075,
          1.        , 0.        , 0.        , 0.        ],
         [0.71615574, 0.158     , 0.67165141, 1.        , 0.        ,
          0.        , 0.        , 0.28384426, 0.158     , 0.42044075,
          1.        , 0.        , 0.        , 0.        ],
         [0.71615574, 0.176     , 0.67165141, 1.        , 0.        ,
          0.        , 0.        , 0.28384426, 0.176     , 0.42044075,
          1.        , 0.        , 0.        , 0.        ],
         [0.71615574, 0.194     , 0.67165141, 1.        , 0.        ,
          0.        , 0.        , 0.28384426, 0.194     , 0.42044075,
          1.        , 0.        , 0.        , 0.        ],
         [0.71615574, 0.212     , 0.67165141, 1.        , 0.        ,
          0.        , 0.        , 0.28384426, 0.212     , 0.42044075,
          1.        , 0.        , 0.        , 0.        ],
         [0.71615574, 0.23      , 0.67165141, 1.        , 0.        ,
          0.        , 0.        , 0.28384426, 0.23      , 0.42044075,
          1.        , 0.        , 0.        , 0.        ]]]])
p act_seqs
(Pdb) array([[[[-0.33299884,  0.        , -0.3729769 , ...,  0.        ,
           0.        ,  0.        ],
         [-0.33299884,  0.        , -0.3729769 , ...,  0.        ,
           0.        ,  0.        ],
         [-0.33299884,  0.        , -0.3729769 , ...,  0.        ,
           0.        ,  0.        ],
         ...,
         [ 0.33299884,  0.        ,  0.3729769 , ...,  0.        ,
           0.        ,  0.        ],
         [ 0.33299884,  0.        ,  0.3729769 , ...,  0.        ,
           0.        ,  0.        ],
         [ 0.33299884,  0.        ,  0.3729769 , ...,  0.        ,
           0.        ,  0.        ]]],


       [[[ 0.44113183,  0.        , -0.23537779, ...,  0.        ,
           0.        ,  0.        ],
         [ 0.44113183,  0.        , -0.23537779, ...,  0.        ,
           0.        ,  0.        ],
         [ 0.44113183,  0.        , -0.23537779, ...,  0.        ,
           0.        ,  0.        ],
         ...,
         [-0.44113183,  0.        ,  0.23537779, ...,  0.        ,
           0.        ,  0.        ],
         [-0.44113183,  0.        ,  0.23537779, ...,  0.        ,
           0.        ,  0.        ],
         [-0.44113183,  0.        ,  0.23537779, ...,  0.        ,
           0.        ,  0.        ]]],


       [[[ 0.49895486,  0.        , -0.03231152, ...,  0.        ,
           0.        ,  0.        ],
         [ 0.49895486,  0.        , -0.03231152, ...,  0.        ,
           0.        ,  0.        ],
         [ 0.49895486,  0.        , -0.03231152, ...,  0.        ,
           0.        ,  0.        ],
         ...,
         [-0.49895486,  0.        ,  0.03231152, ...,  0.        ,
           0.        ,  0.        ],
         [-0.49895486,  0.        ,  0.03231152, ...,  0.        ,
           0.        ,  0.        ],
         [-0.49895486,  0.        ,  0.03231152, ...,  0.        ,
           0.        ,  0.        ]]],


       [[[-0.43231148,  0.        , -0.25121066, ...,  0.        ,
           0.        ,  0.        ],
         [-0.43231148,  0.        , -0.25121066, ...,  0.        ,
           0.        ,  0.        ],
         [-0.43231148,  0.        , -0.25121066, ...,  0.        ,
           0.        ,  0.        ],
         ...,
         [ 0.43231148,  0.        ,  0.25121066, ...,  0.        ,
           0.        ,  0.        ],
         [ 0.43231148,  0.        ,  0.25121066, ...,  0.        ,
           0.        ,  0.        ],
         [ 0.43231148,  0.        ,  0.25121066, ...,  0.        ,
           0.        ,  0.        ]]]], dtype=float32)


tensor([[[[0.4789, 0.0982, 0.4606],
          [0.6069, 0.1723, 0.6136],
          [0.4146, 0.1689, 0.6082],
          ...,
          [0.6040, 0.0731, 0.5341],
          [0.4882, 0.1653, 0.5202],
          [0.5377, 0.0916, 0.4424]],

         [[0.4792, 0.0997, 0.4602],
          [0.6037, 0.1714, 0.6131],
          [0.4164, 0.1745, 0.6091],
          ...,
          [0.6031, 0.0738, 0.5344],
          [0.4900, 0.1636, 0.5195],
          [0.5364, 0.0932, 0.4433]],

         [[0.4790, 0.1003, 0.4603],
          [0.6027, 0.1735, 0.6120],
          [0.4173, 0.1744, 0.6101],
          ...,
          [0.6024, 0.0741, 0.5343],
          [0.4920, 0.1626, 0.5186],
          [0.5355, 0.0947, 0.4442]],

         ...,

         [[0.4956, 0.1236, 0.4792],
          [0.5896, 0.1728, 0.5481],
          [0.4187, 0.1783, 0.6124],
          ...,
          [0.5572, 0.0706, 0.5241],
          [0.5250, 0.1806, 0.4806],
          [0.5294, 0.0970, 0.4505]],

         [[0.4947, 0.1283, 0.4772],
          [0.5917, 0.1724, 0.5485],
          [0.4190, 0.1783, 0.6132],
          ...,
          [0.5551, 0.0705, 0.5244],
          [0.5261, 0.1810, 0.4800],
          [0.5288, 0.0972, 0.4512]],

         [[0.4943, 0.1307, 0.4767],
          [0.5931, 0.1707, 0.5468],
          [0.4194, 0.1791, 0.6142],
          ...,
          [0.5536, 0.0708, 0.5240],
          [0.5281, 0.1799, 0.4793],
          [0.5281, 0.0972, 0.4519]]],


        [[[0.4789, 0.0982, 0.4606],
          [0.6069, 0.1723, 0.6136],
          [0.4146, 0.1689, 0.6082],
          ...,
          [0.6040, 0.0731, 0.5341],
          [0.4882, 0.1653, 0.5202],
          [0.5377, 0.0916, 0.4424]],

         [[0.4792, 0.0997, 0.4602],
          [0.6037, 0.1714, 0.6131],
          [0.4164, 0.1745, 0.6091],
          ...,
          [0.6031, 0.0738, 0.5344],
          [0.4900, 0.1636, 0.5195],
          [0.5364, 0.0932, 0.4433]],

         [[0.4790, 0.1003, 0.4603],
          [0.6027, 0.1735, 0.6120],
          [0.4173, 0.1744, 0.6101],
          ...,
          [0.6024, 0.0741, 0.5343],
          [0.4920, 0.1626, 0.5186],
          [0.5355, 0.0947, 0.4442]],

         ...,

         [[0.4681, 0.0978, 0.4721],
          [0.5999, 0.1676, 0.6069],
          [0.4398, 0.1724, 0.6228],
          ...,
          [0.5550, 0.0723, 0.5307],
          [0.5345, 0.1618, 0.4889],
          [0.4963, 0.1052, 0.4540]],

         [[0.4674, 0.0982, 0.4726],
          [0.5997, 0.1672, 0.6075],
          [0.4404, 0.1700, 0.6233],
          ...,
          [0.5531, 0.0727, 0.5305],
          [0.5371, 0.1599, 0.4860],
          [0.4958, 0.1065, 0.4543]],

         [[0.4668, 0.0984, 0.4732],
          [0.5997, 0.1681, 0.6078],
          [0.4411, 0.1675, 0.6236],
          ...,
          [0.5516, 0.0727, 0.5302],
          [0.5388, 0.1584, 0.4841],
          [0.4953, 0.1075, 0.4549]]],


        [[[0.4789, 0.0982, 0.4606],
          [0.6069, 0.1723, 0.6136],
          [0.4146, 0.1689, 0.6082],
          ...,
          [0.6040, 0.0731, 0.5341],
          [0.4882, 0.1653, 0.5202],
          [0.5377, 0.0916, 0.4424]],

         [[0.4792, 0.0997, 0.4602],
          [0.6037, 0.1714, 0.6131],
          [0.4164, 0.1745, 0.6091],
          ...,
          [0.6031, 0.0738, 0.5344],
          [0.4900, 0.1636, 0.5195],
          [0.5364, 0.0932, 0.4433]],

         [[0.4790, 0.1003, 0.4603],
          [0.6027, 0.1735, 0.6120],
          [0.4173, 0.1744, 0.6101],
          ...,
          [0.6024, 0.0741, 0.5343],
          [0.4920, 0.1626, 0.5186],
          [0.5355, 0.0947, 0.4442]],

         ...,

         [[0.4743, 0.1047, 0.4700],
          [0.5504, 0.1890, 0.6316],
          [0.4773, 0.2046, 0.5845],
          ...,
          [0.5410, 0.0718, 0.5191],
          [0.5196, 0.1632, 0.4790],
          [0.5183, 0.1054, 0.4442]],

         [[0.4744, 0.1053, 0.4706],
          [0.5528, 0.1881, 0.6316],
          [0.4768, 0.2029, 0.5820],
          ...,
          [0.5399, 0.0715, 0.5193],
          [0.5212, 0.1625, 0.4775],
          [0.5180, 0.1050, 0.4436]],

         [[0.4743, 0.1054, 0.4712],
          [0.5546, 0.1875, 0.6316],
          [0.4758, 0.2018, 0.5797],
          ...,
          [0.5378, 0.0716, 0.5200],
          [0.5225, 0.1626, 0.4758],
          [0.5178, 0.1048, 0.4432]]],


        [[[0.4789, 0.0982, 0.4606],
          [0.6069, 0.1723, 0.6136],
          [0.4146, 0.1689, 0.6082],
          ...,
          [0.6040, 0.0731, 0.5341],
          [0.4882, 0.1653, 0.5202],
          [0.5377, 0.0916, 0.4424]],

         [[0.4792, 0.0997, 0.4602],
          [0.6037, 0.1714, 0.6131],
          [0.4164, 0.1745, 0.6091],
          ...,
          [0.6031, 0.0738, 0.5344],
          [0.4900, 0.1636, 0.5195],
          [0.5364, 0.0932, 0.4433]],

         [[0.4790, 0.1003, 0.4603],
          [0.6027, 0.1735, 0.6120],
          [0.4173, 0.1744, 0.6101],
          ...,
          [0.6024, 0.0741, 0.5343],
          [0.4920, 0.1626, 0.5186],
          [0.5355, 0.0947, 0.4442]],

         ...,

         [[0.4889, 0.1076, 0.4765],
          [0.5366, 0.1649, 0.6158],
          [0.4189, 0.1795, 0.6114],
          ...,
          [0.5496, 0.0716, 0.5203],
          [0.5205, 0.1799, 0.4911],
          [0.5253, 0.1166, 0.4386]],

         [[0.4894, 0.1071, 0.4769],
          [0.5364, 0.1644, 0.6159],
          [0.4190, 0.1795, 0.6126],
          ...,
          [0.5467, 0.0718, 0.5213],
          [0.5221, 0.1788, 0.4893],
          [0.5249, 0.1191, 0.4384]],

         [[0.4890, 0.1082, 0.4778],
          [0.5357, 0.1618, 0.6171],
          [0.4193, 0.1788, 0.6134],
          ...,
          [0.5456, 0.0716, 0.5209],
          [0.5225, 0.1779, 0.4881],
          [0.5238, 0.1244, 0.4371]]]])

tensor([-0.0516, -0.0286, -0.0377, -0.0453])

In [None]:
p state_cur
(Pdb) tensor([[[0.3904, 0.0959, 0.4351],
         [0.6174, 0.1879, 0.5990],
         [0.6135, 0.0670, 0.3950],
         ...,
         [0.6636, 0.1940, 0.3645],
         [0.6636, 0.2120, 0.3645],
         [0.6636, 0.2300, 0.3645]],

        [[0.4352, 0.1039, 0.4383],
         [0.6189, 0.0736, 0.6159],
         [0.6146, 0.1853, 0.4382],
         ...,
         [0.6561, 0.1940, 0.3712],
         [0.6561, 0.2120, 0.3712],
         [0.6561, 0.2300, 0.3712]],

        [[0.4693, 0.1672, 0.5154],
         [0.6146, 0.0767, 0.3904],
         [0.6201, 0.0772, 0.6049],
         ...,
         [0.6487, 0.1940, 0.3779],
         [0.6487, 0.2120, 0.3779],
         [0.6487, 0.2300, 0.3779]],

        [[0.4785, 0.0970, 0.4615],
         [0.6114, 0.1729, 0.6151],
         [0.4019, 0.1768, 0.6142],
         ...,
         [0.6413, 0.1940, 0.3846],
         [0.6413, 0.2120, 0.3846],
         [0.6413, 0.2300, 0.3846]]])
p init_pose_seqs
(Pdb) tensor([[[[0.2794, 0.0500, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.0500, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.0680, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.0680, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.0860, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.0860, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1040, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1040, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1220, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1220, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1400, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1400, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1580, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1580, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1760, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1760, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1940, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1940, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.2120, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.2120, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.2300, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.2300, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000]]],


        [[[0.2794, 0.0500, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.0500, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.0680, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.0680, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.0860, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.0860, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1040, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1040, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1220, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1220, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1400, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1400, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1580, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1580, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1760, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1760, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1940, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1940, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.2120, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.2120, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.2300, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.2300, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000]]],


        [[[0.2794, 0.0500, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.0500, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.0680, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.0680, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.0860, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.0860, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1040, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1040, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1220, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1220, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1400, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1400, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1580, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1580, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1760, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1760, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1940, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1940, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.2120, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.2120, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.2300, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.2300, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000]]],


        [[[0.2794, 0.0500, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.0500, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.0680, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.0680, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.0860, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.0860, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1040, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1040, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1220, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1220, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1400, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1400, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1580, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1580, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1760, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1760, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.1940, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.1940, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.2120, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.2120, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000],
          [0.2794, 0.2300, 0.6094, 1.0000, 0.0000, 0.0000, 0.0000, 0.7206,
           0.2300, 0.3740, 1.0000, 0.0000, 0.0000, 0.0000]]]],
       dtype=torch.float64)
p act_seq
(Pdb) *** NameError: name 'act_seq' is not defined
p act_seqs
(Pdb) tensor([[[[ 0.4411,  0.0000, -0.2354,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.4411,  0.0000, -0.2354,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.4411,  0.0000, -0.2354,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [-0.4411,  0.0000,  0.2354,  ...,  0.0000,  0.0000,  0.0000],
          [-0.4411,  0.0000,  0.2354,  ...,  0.0000,  0.0000,  0.0000],
          [-0.4411,  0.0000,  0.2354,  ...,  0.0000,  0.0000,  0.0000]]],


        [[[ 0.4411,  0.0000, -0.2354,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.4411,  0.0000, -0.2354,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.4411,  0.0000, -0.2354,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [-0.4411,  0.0000,  0.2354,  ...,  0.0000,  0.0000,  0.0000],
          [-0.4411,  0.0000,  0.2354,  ...,  0.0000,  0.0000,  0.0000],
          [-0.4411,  0.0000,  0.2354,  ...,  0.0000,  0.0000,  0.0000]]],


        [[[ 0.4411,  0.0000, -0.2354,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.4411,  0.0000, -0.2354,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.4411,  0.0000, -0.2354,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [-0.4411,  0.0000,  0.2354,  ...,  0.0000,  0.0000,  0.0000],
          [-0.4411,  0.0000,  0.2354,  ...,  0.0000,  0.0000,  0.0000],
          [-0.4411,  0.0000,  0.2354,  ...,  0.0000,  0.0000,  0.0000]]],


        [[[ 0.4411,  0.0000, -0.2354,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.4411,  0.0000, -0.2354,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.4411,  0.0000, -0.2354,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [-0.4411,  0.0000,  0.2354,  ...,  0.0000,  0.0000,  0.0000],
          [-0.4411,  0.0000,  0.2354,  ...,  0.0000,  0.0000,  0.0000],
          [-0.4411,  0.0000,  0.2354,  ...,  0.0000,  0.0000,  0.0000]]]],
       dtype=torch.float64)

tensor([[[[0.4789, 0.0982, 0.4606],
          [0.6069, 0.1723, 0.6136],
          [0.4146, 0.1689, 0.6082],
          ...,
          [0.6040, 0.0731, 0.5341],
          [0.4882, 0.1653, 0.5202],
          [0.5377, 0.0916, 0.4424]],

         [[0.4792, 0.0997, 0.4602],
          [0.6037, 0.1714, 0.6131],
          [0.4164, 0.1745, 0.6091],
          ...,
          [0.6031, 0.0738, 0.5344],
          [0.4900, 0.1636, 0.5195],
          [0.5364, 0.0932, 0.4433]],

         [[0.4790, 0.1003, 0.4603],
          [0.6027, 0.1735, 0.6120],
          [0.4173, 0.1744, 0.6101],
          ...,
          [0.6024, 0.0741, 0.5343],
          [0.4920, 0.1626, 0.5186],
          [0.5355, 0.0947, 0.4442]],

         ...,

         [[0.4720, 0.1044, 0.4686],
          [0.5986, 0.1561, 0.6086],
          [0.4393, 0.1722, 0.6263],
          ...,
          [0.5576, 0.0714, 0.5278],
          [0.5206, 0.1617, 0.4995],
          [0.5231, 0.0970, 0.4546]],

         [[0.4727, 0.1042, 0.4689],
          [0.5992, 0.1569, 0.6080],
          [0.4403, 0.1707, 0.6273],
          ...,
          [0.5556, 0.0712, 0.5278],
          [0.5210, 0.1634, 0.5000],
          [0.5228, 0.0983, 0.4550]],

         [[0.4726, 0.1056, 0.4696],
          [0.5994, 0.1564, 0.6086],
          [0.4413, 0.1705, 0.6287],
          ...,
          [0.5540, 0.0711, 0.5284],
          [0.5220, 0.1639, 0.4997],
          [0.5231, 0.0983, 0.4549]]],


        [[[0.4789, 0.0982, 0.4606],
          [0.6069, 0.1723, 0.6136],
          [0.4146, 0.1689, 0.6082],
          ...,
          [0.6040, 0.0731, 0.5341],
          [0.4882, 0.1653, 0.5202],
          [0.5377, 0.0916, 0.4424]],

         [[0.4792, 0.0997, 0.4602],
          [0.6037, 0.1714, 0.6131],
          [0.4164, 0.1745, 0.6091],
          ...,
          [0.6031, 0.0738, 0.5344],
          [0.4900, 0.1636, 0.5195],
          [0.5364, 0.0932, 0.4433]],

         [[0.4790, 0.1003, 0.4603],
          [0.6027, 0.1735, 0.6120],
          [0.4173, 0.1744, 0.6101],
          ...,
          [0.6024, 0.0741, 0.5343],
          [0.4920, 0.1626, 0.5186],
          [0.5355, 0.0947, 0.4442]],

         ...,

         [[0.4720, 0.1044, 0.4686],
          [0.5986, 0.1561, 0.6086],
          [0.4393, 0.1722, 0.6263],
          ...,
          [0.5576, 0.0714, 0.5278],
          [0.5206, 0.1617, 0.4995],
          [0.5231, 0.0970, 0.4546]],

         [[0.4727, 0.1042, 0.4689],
          [0.5992, 0.1569, 0.6080],
          [0.4403, 0.1707, 0.6273],
          ...,
          [0.5556, 0.0712, 0.5278],
          [0.5210, 0.1634, 0.5000],
          [0.5228, 0.0983, 0.4550]],

         [[0.4726, 0.1056, 0.4696],
          [0.5994, 0.1564, 0.6086],
          [0.4413, 0.1705, 0.6287],
          ...,
          [0.5540, 0.0711, 0.5284],
          [0.5220, 0.1639, 0.4997],
          [0.5231, 0.0983, 0.4549]]],


        [[[0.4789, 0.0982, 0.4606],
          [0.6069, 0.1723, 0.6136],
          [0.4146, 0.1689, 0.6082],
          ...,
          [0.6040, 0.0731, 0.5341],
          [0.4882, 0.1653, 0.5202],
          [0.5377, 0.0916, 0.4424]],

         [[0.4792, 0.0997, 0.4602],
          [0.6037, 0.1714, 0.6131],
          [0.4164, 0.1745, 0.6091],
          ...,
          [0.6031, 0.0738, 0.5344],
          [0.4900, 0.1636, 0.5195],
          [0.5364, 0.0932, 0.4433]],

         [[0.4790, 0.1003, 0.4603],
          [0.6027, 0.1735, 0.6120],
          [0.4173, 0.1744, 0.6101],
          ...,
          [0.6024, 0.0741, 0.5343],
          [0.4920, 0.1626, 0.5186],
          [0.5355, 0.0947, 0.4442]],

         ...,

         [[0.4720, 0.1044, 0.4686],
          [0.5986, 0.1561, 0.6086],
          [0.4393, 0.1722, 0.6263],
          ...,
          [0.5576, 0.0714, 0.5278],
          [0.5206, 0.1617, 0.4995],
          [0.5231, 0.0970, 0.4546]],

         [[0.4727, 0.1042, 0.4689],
          [0.5992, 0.1569, 0.6080],
          [0.4403, 0.1707, 0.6273],
          ...,
          [0.5556, 0.0712, 0.5278],
          [0.5210, 0.1634, 0.5000],
          [0.5228, 0.0983, 0.4550]],

         [[0.4726, 0.1056, 0.4696],
          [0.5994, 0.1564, 0.6086],
          [0.4413, 0.1705, 0.6287],
          ...,
          [0.5540, 0.0711, 0.5284],
          [0.5220, 0.1639, 0.4997],
          [0.5231, 0.0983, 0.4549]]],


        [[[0.4789, 0.0982, 0.4606],
          [0.6069, 0.1723, 0.6136],
          [0.4146, 0.1689, 0.6082],
          ...,
          [0.6040, 0.0731, 0.5341],
          [0.4882, 0.1653, 0.5202],
          [0.5377, 0.0916, 0.4424]],

         [[0.4792, 0.0997, 0.4602],
          [0.6037, 0.1714, 0.6131],
          [0.4164, 0.1745, 0.6091],
          ...,
          [0.6031, 0.0738, 0.5344],
          [0.4900, 0.1636, 0.5195],
          [0.5364, 0.0932, 0.4433]],

         [[0.4790, 0.1003, 0.4603],
          [0.6027, 0.1735, 0.6120],
          [0.4173, 0.1744, 0.6101],
          ...,
          [0.6024, 0.0741, 0.5343],
          [0.4920, 0.1626, 0.5186],
          [0.5355, 0.0947, 0.4442]],

         ...,

         [[0.4720, 0.1044, 0.4686],
          [0.5986, 0.1561, 0.6086],
          [0.4393, 0.1722, 0.6263],
          ...,
          [0.5576, 0.0714, 0.5278],
          [0.5206, 0.1617, 0.4995],
          [0.5231, 0.0970, 0.4546]],

         [[0.4727, 0.1042, 0.4689],
          [0.5992, 0.1569, 0.6080],
          [0.4403, 0.1707, 0.6273],
          ...,
          [0.5556, 0.0712, 0.5278],
          [0.5210, 0.1634, 0.5000],
          [0.5228, 0.0983, 0.4550]],

         [[0.4726, 0.1056, 0.4696],
          [0.5994, 0.1564, 0.6086],
          [0.4413, 0.1705, 0.6287],
          ...,
          [0.5540, 0.0711, 0.5284],
          [0.5220, 0.1639, 0.4997],
          [0.5231, 0.0983, 0.4549]]]], grad_fn=<StackBackward>)

tensor([-0.0322, -0.0322, -0.0322, -0.0322], grad_fn=<StackBackward>)