# Policy
## Overview
In Tianshou, both the agent and the core DRL algorithm are implemented in the Policy module. Tianshou provides more than 20 Policy modules, each representing one DRL algorithm. All Policy modules inherit from the `BasePolicy` class and share the same interface.

## Creating your own Policy
We will use the simple REINFORCE algorithm to show how a Policy module is implemented.

### Initialisation
First, we create `REINFORCEPolicy` by inheriting from Tianshou's `BasePolicy`.

In [1]:
from typing import Any, Dict, List, Optional, Type, Union
import numpy as np
import torch
from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as
from tianshou.policy import BasePolicy

class REINFORCEPolicy(BasePolicy):
    """Implementation of REINFORCE algorithm"""
    def __init__(self):
        super().__init__()

The Policy module mainly does two things:
1. `policy.forward()` receives observations and other information (stored in a Batch) from the environment and returns a new Batch containing the action.
2. `policy.update()` receives training data sampled from the replay buffer, updates the policy, and returns logging details.

We also need to take care of the following things:
1. Since Tianshou is a Deep RL library, there should be a policy network in our Policy module, as well as a Torch optimiser.
2. In Tianshou's BasePolicy, `policy.update()` first calls `Policy.process_fn()` to preprocess the training data and compute quantities such as episodic returns (gradient-free); it then calls `Policy.learn()` to perform the back-propagation.
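
The call order described above can be illustrated with a toy class. This is only a sketch in plain Python: `ToyPolicy` and its dictionary-based batch are hypothetical stand-ins, not Tianshou's `BasePolicy` or `Batch`, and the step of sampling from a replay buffer is left out.

```python
from typing import Any, Dict, List


class ToyPolicy:
    """Hypothetical stand-in for a policy; not Tianshou's BasePolicy."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def process_fn(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        # Gradient-free preprocessing, e.g. attaching returns to the batch.
        self.calls.append("process_fn")
        batch["returns"] = list(batch["rew"])  # placeholder, no discounting
        return batch

    def learn(self, batch: Dict[str, Any]) -> Dict[str, List[float]]:
        # The gradient step would happen here; we only record the call.
        self.calls.append("learn")
        return {"loss": [0.0]}

    def update(self, batch: Dict[str, Any]) -> Dict[str, List[float]]:
        # update() chains the two hooks: preprocess first, then learn.
        batch = self.process_fn(batch)
        return self.learn(batch)


toy = ToyPolicy()
stats = toy.update({"rew": [1.0, 1.0, 1.0]})
print(toy.calls)  # ['process_fn', 'learn']
print(stats)      # {'loss': [0.0]}
```

Keeping preprocessing and learning in separate hooks means a new algorithm usually only needs to override those two methods, while the surrounding update logic stays shared.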

In [3]:
from typing import Any, Dict
import numpy as np
from tianshou.data import Batch, ReplayBuffer


class REINFORCEPolicy(BasePolicy):
    """Implementation of REINFORCE algorithm"""
    def __init__(self, model: torch.nn.Module, optim: torch.optim.Optimizer):
        super().__init__()
        self.actor = model
        self.optim = optim
    
    def forward(self, batch: Batch) -> Batch:
        """Compute action over the given batch data"""
        act = None
        return Batch(act=act)
    
    def process_fn(self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray) -> Batch:
        """compute the discounted returns for each transition"""
        pass

    def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:
        """perform the back-propagation"""
        return

### Policy.forward()
According to the REINFORCE equation in Spinning Up's documentation, we need to map each observation to an action distribution over the action space using a neural network (`self.actor`).
$$\hat{g}=\frac{1}{|D|}\sum_{\tau\in D}{\sum_{t=0}^{T}{\nabla_{\theta}\log\pi_{\theta}(a_{t}|s_{t})R(\tau)}}$$
Let us suppose the action space is discrete, and the distribution is a simple categorical distribution.

In [4]:
def forward(self, batch: Batch) -> Batch:
    """compute action over the given batch data"""
    self.dist_fn = torch.distributions.Categorical
    # the actor returns (logits, hidden_state); only the logits are needed here
    logits, _ = self.actor(batch.obs)
    dist = self.dist_fn(logits)
    act = dist.sample()
    return Batch(act=act, dist=dist)
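
To see what a categorical distribution contributes to `forward()`, the sketch below reproduces the two operations in plain Python: sampling an action index from a probability vector, and taking the log-probability of that action, which is the $\log\pi_{\theta}(a_{t}|s_{t})$ term of the gradient estimate. The helpers `softmax` and `sample_action` are hypothetical names introduced for illustration; the policy itself relies on `torch.distributions.Categorical`.

```python
import math
import random
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Turn unnormalised scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_action(probs: List[float], rng: random.Random) -> int:
    """Draw one action index according to the categorical probabilities."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


rng = random.Random(0)
probs = softmax([0.0, 0.0])      # two equally scored actions
act = sample_action(probs, rng)  # either 0 or 1
log_prob = math.log(probs[act])  # the log pi(a|s) term in the gradient
print(probs)  # [0.5, 0.5]
```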

### Policy.process_fn()
Now that we have defined our actor, given training data we can set up a loss function and optimise our neural network. Before that, however, we must calculate the episodic return for every step in our training data in order to construct the REINFORCE loss function.

Calculating episodic returns is not hard, since `ReplayBuffer.next()` lets us step through an episode and access every subsequent reward. A more convenient way is to use the built-in method `compute_episodic_return()` inherited from `BasePolicy`.

In [5]:
def process_fn(self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray) -> Batch:
    """compute the discounted returns for each transition"""
    returns, _ = self.compute_episodic_return(batch, buffer, indices, gamma=0.99, gae_lambda=1.0)
    batch.returns = returns
    return batch

`BasePolicy.compute_episodic_return()` can also be used to compute GAE. A similar method is `BasePolicy.compute_nstep_return()`.
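
To make the quantity computed in `process_fn()` concrete, here is a minimal sketch of discounted reward-to-go in plain Python. It is not Tianshou's implementation: `discounted_returns` is a hypothetical helper, it assumes no value estimates are used, and it assumes an unfinished episode contributes nothing beyond the steps already collected.

```python
from typing import List


def discounted_returns(rews: List[float], dones: List[bool], gamma: float) -> List[float]:
    """Reward-to-go per step, restarting the sum at each episode end."""
    returns = [0.0] * len(rews)
    running = 0.0
    # Walk backwards so each step can reuse the return of the next step.
    for t in reversed(range(len(rews))):
        if dones[t]:
            running = 0.0  # nothing after a terminal step belongs to this episode
        running = rews[t] + gamma * running
        returns[t] = running
    return returns


# One finished 3-step episode followed by an unfinished 2-step one.
rews = [1.0, 1.0, 1.0, 1.0, 1.0]
dones = [False, False, True, False, False]
rets = discounted_returns(rews, dones, gamma=0.5)
print(rets)  # [1.75, 1.5, 1.0, 1.5, 1.0]
```

The backward pass keeps the computation linear in the number of steps, and the reset at each terminal step prevents rewards from leaking across episode boundaries.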

### Policy.learn()
The data batch returned by `Policy.process_fn()` flows into `Policy.learn()`. We can finally construct our loss function and perform the back-propagation.

In [6]:
def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:
    """perform the back-propagation"""
    logging_losses = []
    for _ in range(repeat):
        for minibatch in batch.split(batch_size, merge_last=True):
            self.optim.zero_grad()
            result = self(minibatch)
            dist = result.dist
            act = to_torch_as(minibatch.act, result.act)
            ret = to_torch(minibatch.returns, torch.float, result.act.device)
            log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)
            loss = -(log_prob * ret).mean()
            loss.backward()
            self.optim.step()
            logging_losses.append(loss.item())
    return {"loss": logging_losses}
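
The loss line in `learn()` is the negative mean of log-probability times return, so minimising it follows the REINFORCE gradient estimate. The sketch below evaluates that expression on two hand-picked transitions in plain Python; `reinforce_loss` is a hypothetical helper and the numbers are made up for illustration.

```python
import math
from typing import List


def reinforce_loss(log_probs: List[float], returns: List[float]) -> float:
    """Negative mean of log-probability times return."""
    assert len(log_probs) == len(returns)
    weighted = [lp * ret for lp, ret in zip(log_probs, returns)]
    return -sum(weighted) / len(weighted)


# Two transitions: action probabilities 0.5 and 0.25, returns 2.0 and 1.0.
log_probs = [math.log(0.5), math.log(0.25)]
returns = [2.0, 1.0]
loss = reinforce_loss(log_probs, returns)
print(loss)
```

Because log-probabilities are negative, a positive return yields a positive loss term, and reducing it pushes the probability of that action up.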

### Implementation

In [12]:
from typing import Any, Dict
import numpy as np
from tianshou.data import Batch, ReplayBuffer


class REINFORCEPolicy(BasePolicy):
    """Implementation of REINFORCE algorithm"""
    def __init__(self, model: torch.nn.Module, optim: torch.optim.Optimizer):
        super().__init__()
        self.actor = model
        self.optim = optim
        self.dist_fn = torch.distributions.Categorical

    def forward(self, batch: Batch) -> Batch:
        """compute action over the given batch data"""
        logits, _ = self.actor(batch.obs)
        dist = self.dist_fn(logits)
        act = dist.sample()
        return Batch(act=act, dist=dist)
    
    def process_fn(self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray) -> Batch:
        """compute the discounted returns for each transition"""
        returns, _ = self.compute_episodic_return(batch, buffer, indices, gamma=0.99, gae_lambda=1.0)
        batch.returns = returns
        return batch
    
    def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:
        """perform the back-propagation"""
        logging_losses = []
        for _ in range(repeat):
            for minibatch in batch.split(batch_size, merge_last=True):
                self.optim.zero_grad()
                result = self(minibatch)
                dist = result.dist
                act = to_torch_as(minibatch.act, result.act)
                ret = to_torch(minibatch.returns, torch.float, result.act.device)
                log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)
                loss = -(log_prob * ret).mean()
                loss.backward()
                self.optim.step()
                logging_losses.append(loss.item())
        return {"loss": logging_losses}

## Use the policy
Note that `BasePolicy` itself inherits from `torch.nn.Module`. As a result, you can consider all Policy modules as a Torch Module. They share similar APIs.

Firstly we will initialise a new REINFORCE policy.

In [13]:
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Actor
import warnings
warnings.filterwarnings('ignore')
state_shape = 4
action_shape = 2
net = Net(state_shape, hidden_sizes=[16, 16], device="cpu")
actor = Actor(net, action_shape, device="cpu").to("cpu")
optim = torch.optim.Adam(actor.parameters(), lr=0.0003)

policy = REINFORCEPolicy(actor, optim)

In [14]:
print(policy)
print("==================================")
for para in policy.parameters():
    print(para.shape)

REINFORCEPolicy(
  (actor): Actor(
    (preprocess): Net(
      (model): MLP(
        (model): Sequential(
          (0): Linear(in_features=4, out_features=16, bias=True)
          (1): ReLU()
          (2): Linear(in_features=16, out_features=16, bias=True)
          (3): ReLU()
        )
      )
    )
    (last): MLP(
      (model): Sequential(
        (0): Linear(in_features=16, out_features=2, bias=True)
      )
    )
  )
)
torch.Size([16, 4])
torch.Size([16])
torch.Size([16, 16])
torch.Size([16])
torch.Size([2, 16])
torch.Size([2])


### Making decisions
Given a batch of observations, the policy can return a batch of actions and other data.

In [15]:
obs_batch = Batch(obs=np.ones(shape=(256, 4)))
action = policy(obs_batch)  # forward() method is called
print(action)

Batch(
    act: tensor([0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1,
                 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1,
                 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0,
                 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1,
                 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1,
                 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1,
                 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
                 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1,
                 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1,
                 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1,
                 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0]),
    dist: Categorical(probs: torch.S

### Save and Load models
Naturally, a Tianshou Policy can be saved and loaded like a normal Torch network.

In [16]:
torch.save(policy.state_dict(), 'policy.pth')
# load_state_dict() is strict by default and raises an error on mismatched keys
policy.load_state_dict(torch.load('policy.pth'))

### Algorithm Updating
We have to collect some data and save it in the ReplayBuffer before updating our agent (policy). Typically we use a collector to collect data, but we leave this part until later, when we have learned about the Collector in Tianshou. For now we generate some fake data.

#### Generating fake data
Firstly, we need to "pretend" that we are using the "Policy" to collect data. We plan to collect 10 transitions so that we can update our algorithm.

In [70]:
import gymnasium as gym
from tianshou.data import Batch, ReplayBuffer
# a buffer is initialised with its maxsize set to 12
buf = ReplayBuffer(size=12)
print(buf)
print("maxsize: {}, data length: {}".format(buf.maxsize, len(buf)))
env = gym.make("CartPole-v0")

ReplayBuffer()
maxsize: 12, data length: 0


Now we are pretending to collect the first episode. The first episode ends at step 3 (perhaps we are performing too badly).

In [71]:
obs = env.reset()[0]
for i in range(3):
    act = policy(Batch(obs=obs[np.newaxis, :])).act.item()
    obs_next, rew, terminated, truncated, info = env.step(act)
    # pretend the episode ends at step 3
    terminated = i == 2
    info["id"] = i
    buf.add(Batch(obs=obs, act=act, rew=rew, terminated=terminated, truncated=truncated, obs_next=obs_next, info=info))
    obs = obs_next

Now we are pretending to collect the second episode. At step 7 the second episode still doesn't end, but we are unwilling to wait, so we stop collecting to update the algorithm.

In [72]:
obs = env.reset()[0]
for i in range(3, 10):
    act = policy(Batch(obs=obs[np.newaxis, :])).act.item()
    obs_next, rew, terminated, truncated, info = env.step(act)
    # pretend this episode never ends
    terminated = False
    truncated = False
    info["id"] = i
    buf.add(Batch(obs=obs, act=act, rew=rew, terminated=terminated, truncated=truncated, obs_next=obs_next, info=info))
    obs = obs_next

Our replay buffer looks like this now.

In [73]:
print(buf)
print("maxsize: {}, data length: {}".format(buf.maxsize, len(buf)))

ReplayBuffer(
    obs: array([[ 0.00588752, -0.01020907, -0.03662223, -0.00836865],
                [ 0.00568334,  0.18541844, -0.0367896 , -0.3123777 ],
                [ 0.00939171, -0.00916062, -0.04303716, -0.0315203 ],
                [ 0.04263549,  0.00945875,  0.01780597,  0.01071217],
                [ 0.04282467, -0.18591398,  0.01802021,  0.30895948],
                [ 0.03910639,  0.00894665,  0.0241994 ,  0.02201366],
                [ 0.03928532,  0.20371334,  0.02463967, -0.26293692],
                [ 0.04335959,  0.00824851,  0.01938093,  0.03741467],
                [ 0.04352456,  0.20308726,  0.02012923, -0.24909092],
                [ 0.0475863 ,  0.39791605,  0.01514741, -0.5353573 ],
                [ 0.        ,  0.        ,  0.        ,  0.        ],
                [ 0.        ,  0.        ,  0.        ,  0.        ]],
               dtype=float32),
    info: Batch(
              id: array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0]),
          ),
    rew: array([1., 1

#### Updates
We now have a replay buffer holding 10 transitions. We can call `Policy.update()` to train.

In [74]:
# 0 means sample all data from the buffer
# batch_size=10 defines the training batch size
# repeat=6 means the training is repeated 6 times
policy.update(0, buf, batch_size=10, repeat=6)

{'loss': [2.2938356399536133,
  2.2932286262512207,
  2.2926220893859863,
  2.292015552520752,
  2.2914087772369385,
  2.290802001953125]}
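
The six loss values follow from the nested loops in `learn()`: each of the 6 repeats splits the 10 sampled transitions into minibatches of size 10, which gives one gradient step per repeat. The sketch below counts the steps in plain Python. `split_sizes` and `num_updates` are hypothetical helpers, and the sketch assumes that `merge_last=True` folds a short final chunk into the one before it; shuffling is ignored.

```python
from typing import List


def split_sizes(length: int, size: int, merge_last: bool = True) -> List[int]:
    """Sizes of consecutive minibatches when cutting `length` items into chunks of `size`."""
    merge = merge_last and length % size > 0
    sizes: List[int] = []
    for start in range(0, length, size):
        if merge and start + 2 * size >= length:
            # Fold the short remainder into this chunk and stop.
            sizes.append(length - start)
            break
        sizes.append(min(size, length - start))
    return sizes


def num_updates(length: int, batch_size: int, repeat: int) -> int:
    """Gradient steps taken by the nested loops in learn()."""
    return repeat * len(split_sizes(length, batch_size))


print(split_sizes(10, 10))                       # [10]
print(split_sizes(10, 4))                        # [4, 6]
print(num_updates(10, batch_size=10, repeat=6))  # 6
```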