# Pretrain Policy with  `MPC dataset`

In [1]:
import numpy as np
import pickle
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import torch.optim as optim

from tqdm import tqdm

from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from policy import GaussianPolicy

  from .autonotebook import tqdm as notebook_tqdm


### Load  `MPC dataset`

In [2]:
with open(file='../data/dataset.pkl', mode='rb') as f:
    dataset = pickle.load(f)

action_batch = torch.Tensor(dataset['action'])
qpos_batch = torch.Tensor(dataset['qpos'])
qvel_batch = torch.Tensor(dataset['qvel'])
obs_batch = torch.cat((qpos_batch, qvel_batch), dim=1)

print("action : ", action_batch.shape)
print("qpos : ", qpos_batch.shape)
print("qvel : ", qvel_batch.shape)

obs_dim = obs_batch.shape[1]
action_dim = action_batch.shape[1]
hidden_dim = 512

del dataset

In [None]:
policy = GaussianPolicy(
    input_dim=obs_dim,
    output_dim=action_dim,
    hidden_dim=hidden_dim,
    is_deterministic=False,
)

RunningMeanStd:  99


In [None]:
class MPCDataset(Dataset):
    def __init__(self, obs, act):
        self.obs = obs
        self.act = act
        assert self.obs.shape[0] == self.act.shape[0]

    def __len__(self):
        return self.obs.shape[0]

    def __getitem__(self,idx):
        return self.obs[idx], self.act[idx]

In [None]:
train_dataset = MPCDataset(obs_batch, action_batch)
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
# test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

In [None]:
num_epoch = 100
optimizer = optim.Adam(policy.parameters(), lr=3e-4)
loss = torch.nn.MSELoss()
def criterion(output: torch.tensor, y: torch.tensor):
    return loss(output, y)

for epoch in range(num_epoch):

    with tqdm(train_dataloader, unit="batch") as tepoch:
        
        for x, y in tepoch:
            
            tepoch.set_description(f"Epoch {epoch+1}")

            optimizer.zero_grad()
                
            output, _ = policy(x)
            l = criterion(output, y)
            l.backward()
            optimizer.step()
            
            tepoch.set_postfix(loss=l.item())


Epoch 1: 100%|██████████| 8875/8875 [01:05<00:00, 135.91batch/s, loss=0.00947]
Epoch 2: 100%|██████████| 8875/8875 [01:04<00:00, 136.82batch/s, loss=0.00685]
Epoch 3: 100%|██████████| 8875/8875 [01:02<00:00, 142.52batch/s, loss=0.0053] 
Epoch 4: 100%|██████████| 8875/8875 [01:01<00:00, 143.59batch/s, loss=0.00489]
Epoch 5: 100%|██████████| 8875/8875 [01:04<00:00, 138.42batch/s, loss=0.0044] 
Epoch 6: 100%|██████████| 8875/8875 [01:08<00:00, 128.94batch/s, loss=0.00427]
Epoch 7: 100%|██████████| 8875/8875 [01:11<00:00, 123.78batch/s, loss=0.0041] 
Epoch 8: 100%|██████████| 8875/8875 [01:12<00:00, 122.58batch/s, loss=0.00407]
Epoch 9: 100%|██████████| 8875/8875 [01:09<00:00, 128.37batch/s, loss=0.00401]
Epoch 10: 100%|██████████| 8875/8875 [01:07<00:00, 131.45batch/s, loss=0.00386]
Epoch 11: 100%|██████████| 8875/8875 [01:07<00:00, 131.49batch/s, loss=0.00376]
Epoch 12: 100%|██████████| 8875/8875 [01:07<00:00, 131.53batch/s, loss=0.00384]
Epoch 13: 100%|██████████| 8875/8875 [01:07<00:00

In [None]:
# torch.save(policy.state_dict(), "pretrained.pth")
torch.save(policy.running_mean_std.state_dict(), "x.pth")

In [None]:
torch.load("pretrained.pth")

OrderedDict([('fc_layers.0.weight',
              tensor([[-3.5663e-02, -6.2427e-02, -6.6950e-02,  ...,  2.3967e-02,
                        5.9153e-04, -2.2630e-02],
                      [ 1.1431e-02,  2.7153e-02,  5.0848e-03,  ..., -8.4272e-03,
                        1.2081e-04,  6.6997e-03],
                      [ 2.3380e-02,  8.8157e-02,  1.2871e-01,  ...,  4.3075e-03,
                       -3.2576e-03,  2.0511e-02],
                      ...,
                      [-7.2053e-02,  1.2269e-02,  1.6802e-01,  ..., -3.0039e-02,
                       -4.1080e-02, -2.3112e-02],
                      [-6.7006e-02, -1.2783e-01,  1.8633e-01,  ..., -2.7223e-02,
                       -1.8189e-02, -7.2118e-03],
                      [ 7.5367e-02,  1.0303e-01, -3.7449e-02,  ...,  4.7731e-03,
                        3.0344e-02,  3.4120e-02]])),
             ('fc_layers.0.bias',
              tensor([-0.1096,  0.0961, -0.3489, -0.3267, -0.0381, -0.2351, -0.0556, -0.1216,
                    