## Convert Stable Baselines CnnPolicy to PyTorch


Github repo: https://github.com/hill-a/stable-baselines

[Documentation](http://stable-baselines.readthedocs.io/)

### Install dependencies



In [None]:
%tensorflow_version 1x

In [None]:
!pip install stable-baselines --upgrade
!git clone https://github.com/araffin/rl-baselines-zoo

In [None]:
%cd /content/rl-baselines-zoo/

/content/rl-baselines-zoo


### Import stable baselines and co

In [1]:
import gym
import torch as th
import torch.nn as nn
import numpy as np

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.cmd_util import make_atari_env

ModuleNotFoundError: No module named 'stable_baselines3.common.cmd_util'

### Load saved policy

In [None]:
model_path = "trained_agents/ppo2/{}.pkl".format('BreakoutNoFrameskip-v4')

baselines_cnn_model = PPO.load(model_path, verbose=2)

for key, value in baselines_cnn_model.get_parameters().items():
    print(key, value.shape)

### Create Pytorch model

In [None]:
class PyTorchCnnPolicy(nn.Module):
    def __init__(self):
        super(PyTorchCnnPolicy, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4, padding=0, bias=True)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=0, bias=True)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0, bias=True)
        self.fc1 = nn.Linear(3136, 512)
        self.fc2 = nn.Linear(512, 4)
        self.relu = nn.ReLU()
        self.out_activ = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        # shape before flattening
        # tf: (?, 7, 7, 64)
        # pytorch: [1, 64, 7, 7]
        x = x.permute(0, 2, 3, 1).contiguous()
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        x = self.out_activ(x)
        return x

### Convert weights from tensorflow to pytorch

In [None]:
def copy_cnn_weights(baselines_model):
    torch_cnn = PyTorchCnnPolicy()
    model_params = baselines_model.get_parameters()
    # Get only the policy parameters
    policy_keys = [key for key in model_params.keys() if "pi" in key or "c" in key]
    policy_params = [model_params[key] for key in policy_keys]

    for (th_key, pytorch_param), key, policy_param in zip(torch_cnn.named_parameters(), policy_keys, policy_params):
        param = policy_param.copy()
        # Copy parameters from stable baselines model to pytorch model

        # Conv layer
        if len(param.shape) == 4:
          # https://gist.github.com/chirag1992m/4c1f2cb27d7c138a4dc76aeddfe940c2
          # Tensorflow 2D Convolutional layer: height * width * input channels * output channels
          # PyTorch 2D Convolutional layer: output channels * input channels * height * width
          param = np.transpose(param, (3, 2, 0, 1))

        # weight of fully connected layer
        if len(param.shape) == 2:
            param = param.T

        # bias
        if 'b' in key:
            param = param.squeeze()

        param = th.from_numpy(param)
        pytorch_param.data.copy_(param.data.clone())

    return torch_cnn

### Convert images to torch format

In [None]:
def obs_to_torch(obs):
    # TF: NHWC
    # PyTorch: NCHW
    # https://discuss.pytorch.org/t/dimensions-of-an-input-image/19439
    obs = np.transpose(obs, (0, 3, 1, 2))
    # Normalize
    obs = obs / 255.0
    obs = th.tensor(obs).float()
    return obs

### Sanity check with one observation

In [None]:
th_model = copy_cnn_weights(baselines_cnn_model)

In [None]:
env = make_atari_env('BreakoutNoFrameskip-v4', num_env=1, seed=0)
# Frame-stacking with 4 frames
env = VecFrameStack(env, n_stack=4)

In [None]:
obs = env.reset()

In [None]:
baselines_cnn_model.action_probability(obs)

array([[9.9991703e-01, 7.7686826e-05, 3.3082727e-06, 1.9453785e-06]],
      dtype=float32)

In [None]:
th_model(obs_to_torch(obs))

tensor([[9.9992e-01, 7.7687e-05, 3.3083e-06, 1.9454e-06]],
       grad_fn=<SoftmaxBackward>)

### Test with trained model and same random seed

In [None]:
env = make_atari_env('BreakoutNoFrameskip-v4', num_env=1, seed=1)
# Frame-stacking with 4 frames
env = VecFrameStack(env, n_stack=4)

episode_reward = 0
done = False
obs = env.reset()
while not done:
    action = th.argmax(th_model(obs_to_torch(obs))).item()
    #action = env.action_space.sample()
    obs, reward, done, _ = env.step([action])
    episode_reward += reward

print(episode_reward)

[87.]


In [None]:
env = make_atari_env('BreakoutNoFrameskip-v4', num_env=1, seed=1)
# Frame-stacking with 4 frames
env = VecFrameStack(env, n_stack=4)

episode_reward = 0
done = False
obs = env.reset()
while not done:
    action, _ = baselines_cnn_model.predict(obs, deterministic=True)
    obs, reward, done, _ = env.step(action)
    episode_reward += reward

print(episode_reward)

[87.]


### Sanity check: random agent

In [None]:
env = make_atari_env('BreakoutNoFrameskip-v4', num_env=1, seed=1)
# Frame-stacking with 4 frames
env = VecFrameStack(env, n_stack=4)

episode_reward = 0
done = False
obs = env.reset()
while not done:
    action = env.action_space.sample()
    obs, reward, done, _ = env.step([action])
    episode_reward += reward

print(episode_reward)

[0.]
