# Pretrain Policy with  `MPC dataset`

In [1]:
import numpy as np
import pickle
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import torch.optim as optim

from tqdm import tqdm

from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from policy import GaussianPolicy
from rl_games.algos_torch.running_mean_std import RunningMeanStd


  from .autonotebook import tqdm as notebook_tqdm


### Load  `MPC dataset`

In [2]:
with open(file='../data/dataset.pkl', mode='rb') as f:
    dataset = pickle.load(f)

action_batch = torch.Tensor(dataset['action'])
qpos_batch = torch.Tensor(dataset['qpos'])
qvel_batch = torch.Tensor(dataset['qvel'])
obs_batch = torch.cat((qpos_batch, qvel_batch), dim=1)

print("action : ", action_batch.shape)
print("qpos : ", qpos_batch.shape)
print("qvel : ", qvel_batch.shape)

obs_dim = obs_batch.shape[1]
action_dim = action_batch.shape[1]
hidden_dim = 256

del dataset

action :  torch.Size([1136000, 26])
qpos :  torch.Size([1136000, 50])
qvel :  torch.Size([1136000, 49])


In [3]:
action_batch[1]

tensor([ 0.2862, -0.2742, -0.0915, -0.0823,  0.2918, -0.0046, -0.0993, -0.0307,
         0.2906, -0.2419, -0.0566,  0.2426, -0.0800,  0.1099, -0.0554, -0.1835,
         0.1174, -0.2895,  0.0853,  0.0523,  0.1197, -0.0270, -0.3124, -0.1802,
         0.0037, -0.0117])

In [4]:
policy = GaussianPolicy(
    input_dim=obs_dim,
    output_dim=action_dim,
    hidden_dim=hidden_dim,
    is_deterministic=False,
)

In [5]:
class MPCDataset(Dataset):
    def __init__(self, obs, act):
        self.obs = obs
        self.act = act
        assert self.obs.shape[0] == self.act.shape[0]

    def __len__(self):
        return self.obs.shape[0]

    def __getitem__(self,idx):
        return self.obs[idx], self.act[idx]

In [6]:
train_dataset = MPCDataset(obs_batch, action_batch)
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
# test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

In [7]:
num_epoch = 100
optimizer = optim.Adam(policy.parameters(), lr=3e-4)
loss = torch.nn.MSELoss()
def criterion(output: torch.tensor, y: torch.tensor):
    return loss(output, y)

for epoch in range(num_epoch):

    with tqdm(train_dataloader, unit="batch") as tepoch:
        
        for x, y in tepoch:
            
            tepoch.set_description(f"Epoch {epoch+1}")

            optimizer.zero_grad()
                
            output, _ = policy(x)
            l = criterion(output, y)
            l.backward()
            optimizer.step()
            
            tepoch.set_postfix(loss=l.item())


Epoch 1:   0%|          | 0/8875 [00:00<?, ?batch/s, loss=0.701]

Epoch 1: 100%|██████████| 8875/8875 [00:30<00:00, 293.32batch/s, loss=0.0147]
Epoch 2: 100%|██████████| 8875/8875 [00:28<00:00, 310.59batch/s, loss=0.0111]
Epoch 3: 100%|██████████| 8875/8875 [00:28<00:00, 310.48batch/s, loss=0.0114] 
Epoch 4: 100%|██████████| 8875/8875 [00:29<00:00, 299.96batch/s, loss=0.0105] 
Epoch 5: 100%|██████████| 8875/8875 [00:28<00:00, 309.99batch/s, loss=0.00956]
Epoch 6: 100%|██████████| 8875/8875 [00:28<00:00, 309.86batch/s, loss=0.00843]
Epoch 7: 100%|██████████| 8875/8875 [00:28<00:00, 315.37batch/s, loss=0.00856]
Epoch 8: 100%|██████████| 8875/8875 [00:29<00:00, 296.06batch/s, loss=0.0079] 
Epoch 9: 100%|██████████| 8875/8875 [00:30<00:00, 295.57batch/s, loss=0.00752]
Epoch 10: 100%|██████████| 8875/8875 [00:29<00:00, 297.41batch/s, loss=0.00839]
Epoch 11: 100%|██████████| 8875/8875 [00:29<00:00, 297.97batch/s, loss=0.00776]
Epoch 12: 100%|██████████| 8875/8875 [00:30<00:00, 294.20batch/s, loss=0.00773]
Epoch 13: 100%|██████████| 8875/8875 [00:29<00:00, 

In [None]:
torch.save(policy.state_dict(), "pretrained.pth")

In [None]:
torch.load("pretrained.pth")

OrderedDict([('fc_layers.0.weight',
              tensor([[ 0.1215, -0.1565,  0.1505,  ...,  0.0220, -0.0733,  0.0974],
                      [-0.0485, -0.0285,  0.0772,  ..., -0.0013,  0.1378,  0.0375],
                      [ 0.0746, -0.1065,  0.1095,  ...,  0.0401, -0.0087,  0.0433],
                      ...,
                      [-0.0819, -0.0598,  0.1581,  ...,  0.1039, -0.0975, -0.0007],
                      [-0.3158,  0.0874,  0.1696,  ..., -0.0365, -0.0675,  0.0555],
                      [ 0.0256, -0.0684,  0.1021,  ...,  0.0195,  0.0358, -0.0326]])),
             ('fc_layers.0.bias',
              tensor([-0.0241,  0.0457,  0.0638,  0.1623,  0.0659, -0.0092,  0.0527,  0.1868,
                      -0.0146, -0.1315,  0.0280,  0.2140,  0.0372,  0.0382, -0.0710,  0.1094,
                       0.1125,  0.0447, -0.0277,  0.0893,  0.1810, -0.0226, -0.0008,  0.0929,
                       0.1461,  0.0260,  0.0266, -0.0108, -0.0858,  0.0120,  0.0214,  0.1273,
                    