# Pretrain Policy with  `MPC dataset`

In [1]:
import numpy as np
import pickle
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import torch.optim as optim

from tqdm import tqdm

from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from policy import GaussianPolicy

  from .autonotebook import tqdm as notebook_tqdm


### Load  `MPC dataset`

In [2]:
with open(file='../data/qposnoise smpl stand v3.pkl', mode='rb') as f:
    dataset = pickle.load(f)

horizon = 300
action_batch = torch.Tensor(dataset['action'])
qpos_batch = torch.Tensor(dataset['qpos'])
qvel_batch = torch.Tensor(dataset['qvel'])

action_batch = action_batch.reshape(action_batch.shape[0] * horizon, -1)
qpos_batch = qpos_batch.reshape(qpos_batch.shape[0] * horizon, -1)
qvel_batch = qvel_batch.reshape(qvel_batch.shape[0] * horizon, -1)

del dataset

qpos_batch = qpos_batch[:, 2:]
# qpos_batch = torch.cat((qpos_batch[:-2, :],qpos_batch[1:-1, :],qpos_batch[2:, :]),dim=1)
# qvel_batch = torch.cat((qvel_batch[:-2, :],qvel_batch[1:-1, :],qvel_batch[2:, :]),dim=1)

obs_batch = torch.cat((qpos_batch, qvel_batch), dim=1)

# action_batch = action_batch[2:]
obs_dim = obs_batch.shape[1]
action_dim = action_batch.shape[1]
hidden_dim = 512

print("obs : ", obs_batch.shape)
print("action : ", action_batch.shape)


obs :  torch.Size([30000, 85])
action :  torch.Size([30000, 37])


In [3]:
policy = GaussianPolicy(
    input_dim=obs_dim,
    output_dim=action_dim,
    hidden_dim=hidden_dim,
    is_deterministic=False,
)

RunningMeanStd:  85
RunningMeanStd:  37


In [4]:
class MPCDataset(Dataset):
    def __init__(self, obs, act):
        self.obs = obs
        self.act = act
        assert self.obs.shape[0] == self.act.shape[0]

    def __len__(self):
        return self.obs.shape[0]

    def __getitem__(self,idx):
        return self.obs[idx], self.act[idx]

In [5]:
train_dataset = MPCDataset(obs_batch, action_batch)
train_dataloader = DataLoader(train_dataset, batch_size=2048, shuffle=True)
# test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

In [6]:
num_epoch = 1000
optimizer = optim.Adam(policy.parameters(), lr=5e-4, eps=1e-5, betas=(0.9, 0.95))
loss = torch.nn.MSELoss()
def criterion(output: torch.tensor, y: torch.tensor):
    diff = output - y
    return (torch.mean(torch.abs(diff)))
    # return torch.mean(torch.sqrt(torch.mean(torch.square(diff),dim=1)))
    # return loss(output, y)

for epoch in range(num_epoch):

    with tqdm(train_dataloader, unit="batch") as tepoch:
        
        for x, y in tepoch:
            
            tepoch.set_description(f"Epoch {epoch+1}")

            optimizer.zero_grad()
                
            output, _ = policy(x)
            l = criterion(output, y)
            l.backward()
            optimizer.step()
            
            tepoch.set_postfix(loss=l.item())


Epoch 1:   0%|          | 0/15 [00:00<?, ?batch/s]

Epoch 1: 100%|██████████| 15/15 [00:00<00:00, 33.22batch/s, loss=0.189]
Epoch 2: 100%|██████████| 15/15 [00:00<00:00, 34.38batch/s, loss=0.0997]
Epoch 3: 100%|██████████| 15/15 [00:00<00:00, 34.52batch/s, loss=0.0714]
Epoch 4: 100%|██████████| 15/15 [00:00<00:00, 35.68batch/s, loss=0.0583]
Epoch 5: 100%|██████████| 15/15 [00:00<00:00, 34.76batch/s, loss=0.0522]
Epoch 6: 100%|██████████| 15/15 [00:00<00:00, 33.94batch/s, loss=0.0516]
Epoch 7: 100%|██████████| 15/15 [00:00<00:00, 34.14batch/s, loss=0.0444]
Epoch 8: 100%|██████████| 15/15 [00:00<00:00, 34.74batch/s, loss=0.0417]
Epoch 9: 100%|██████████| 15/15 [00:00<00:00, 34.34batch/s, loss=0.0398]
Epoch 10: 100%|██████████| 15/15 [00:00<00:00, 35.47batch/s, loss=0.0366]
Epoch 11: 100%|██████████| 15/15 [00:00<00:00, 35.04batch/s, loss=0.0351]
Epoch 12: 100%|██████████| 15/15 [00:00<00:00, 36.02batch/s, loss=0.0409]
Epoch 13: 100%|██████████| 15/15 [00:00<00:00, 36.59batch/s, loss=0.0404]
Epoch 14: 100%|██████████| 15/15 [00:00<00:00, 3

In [10]:
torch.save(policy.state_dict(), "pretrained_stand qposnoise.pth")

In [15]:
torch.load("pretrained_stand.pth")

OrderedDict([('fc_layers.0.weight',
              tensor([[ 0.0198,  0.0176,  0.0063,  ..., -0.0115, -0.0359, -0.0701],
                      [ 0.0630,  0.0294,  0.0169,  ..., -0.0011,  0.0301,  0.0283],
                      [-0.0802,  0.0443,  0.0835,  ...,  0.0287,  0.0157, -0.0182],
                      ...,
                      [ 0.0817,  0.0680, -0.0971,  ...,  0.0225,  0.0563,  0.0962],
                      [ 0.0821,  0.1084, -0.0794,  ..., -0.0161,  0.0169, -0.0064],
                      [-0.1018,  0.0609,  0.0996,  ...,  0.0138, -0.0244, -0.0263]])),
             ('fc_layers.0.bias',
              tensor([ 8.2246e-02, -9.4981e-02, -1.2945e-01, -1.9070e-01,  1.3913e-02,
                      -5.3622e-02, -4.1444e-04, -6.2849e-02, -1.1321e-01, -1.6851e-01,
                       3.1859e-02, -1.7986e-01, -1.0379e-01, -5.1795e-02, -1.9948e-01,
                       3.1596e-02,  1.0156e-02, -1.5771e-01,  5.3334e-02, -1.8457e-01,
                       1.9521e-02,  3.8616e-02, 