In [None]:
# default_exp models.layers.ou_noise

# OU Noise Layer
> Implementation of Ornstein-Uhlenbeck (OU) noise: a mean-reverting random process used to generate temporally correlated action noise.

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
import torch

In [None]:
#export
class OUNoise:
    """
    Ornstein-Uhlenbeck (OU) noise generator.

    Keeps a state vector and evolves it with the Euler-Maruyama
    discretisation of the OU process:

        x <- x + theta * (mu - x) * dt + sigma * sqrt(dt) * N(0, I)

    With the default ``dt=1.0`` this is exactly
    ``x <- x + theta * (mu - x) + sigma * N(0, I)``, so successive samples
    are correlated and drift back towards ``mu``.
    """
    def __init__(self,
                 embedded_action_size,
                 ou_mu,
                 ou_theta,
                 ou_sigma,
                 ou_epsilon,
                 *,
                 dt=1.0,
                 generator=None):
        """
        Initialize OUNoise and reset its state to ``ou_mu``.

        :param embedded_action_size: size of the noise vector; forwarded to
            ``torch.ones`` / ``torch.randn``, so an int or a shape tuple.
        :param ou_mu: mean the process reverts to; also the initial state.
        :param ou_theta: mean-reversion rate (weight of ``mu - x`` per step).
        :param ou_sigma: scale of the Gaussian perturbation per step.
        :param ou_epsilon: stored as ``self.ou_epsilon`` but not read by this
            class. NOTE(review): presumably used by callers (e.g. to scale or
            decay the noise) -- confirm.
        :param dt: integration time step, must be > 0. The default 1.0
            reproduces the original unit-step update exactly.
        :param generator: optional ``torch.Generator`` used for sampling, so
            the noise can be made reproducible independently of the global
            RNG. ``None`` draws from torch's global RNG.
        :raises ValueError: if ``dt`` is not positive.
        """
        if dt <= 0:
            raise ValueError(f"dt must be positive, got {dt}")
        self.embedded_action_size = embedded_action_size
        self.ou_mu = ou_mu
        self.ou_theta = ou_theta
        self.ou_sigma = ou_sigma
        self.ou_epsilon = ou_epsilon
        self.dt = dt
        self.generator = generator
        self.ou_state = None
        self.reset()

    def reset(self):
        """
        Reset the OU process state to a vector filled with ``ou_mu``.
        """
        self.ou_state = torch.ones(self.embedded_action_size) * self.ou_mu

    def evolve_state(self):
        """
        Advance the OU process state by one step of size ``dt`` (in place).
        """
        # Mean-reverting pull towards mu, scaled by the time step.
        drift = self.ou_theta * (self.ou_mu - self.ou_state) * self.dt
        # Gaussian increment; its std scales with sqrt(dt) (Euler-Maruyama).
        diffusion = self.ou_sigma * (self.dt ** 0.5) * torch.randn(
            self.embedded_action_size, generator=self.generator)
        self.ou_state += drift + diffusion

    def get_ou_noise(self):
        """
        Get the OU noise for one action.

        Evolves the state by one step and returns a copy of it, so callers
        can modify the result without corrupting the process state.

        :return: OU noise tensor (a clone of the updated state).
        """
        self.evolve_state()
        return self.ou_state.clone()

In [None]:
# Seed the global RNG so the sampled noise below is reproducible on re-run.
torch.manual_seed(42)
# __init__ already calls reset(), so the state starts at ou_mu.
noise = OUNoise(embedded_action_size=32,
                ou_mu=0.0,
                ou_theta=0.15,
                ou_sigma=0.2,
                ou_epsilon=1.0)
ou_sample = noise.get_ou_noise()
assert ou_sample.shape == (32,)
ou_sample

tensor([ 0.1521,  0.2014,  0.1552,  0.0324, -0.3615,  0.2302, -0.1692, -0.1778,
         0.0083,  0.0672,  0.0395,  0.0195,  0.2221,  0.0395, -0.1529, -0.2925,
         0.0994, -0.0426,  0.0901,  0.2552,  0.2225, -0.0833, -0.2342, -0.1982,
         0.2106,  0.1198,  0.2413,  0.3598,  0.0585,  0.0011,  0.2037, -0.0334])

In [None]:
#hide
# Record author, timestamp, machine and imported-package versions so readers can reproduce the environment.
!pip install -q watermark
%reload_ext watermark
%watermark -a "Sparsh A." -m -iv -u -t -d

Author: Sparsh A.

Last updated: 2021-12-19 09:15:53

Compiler    : GCC 7.5.0
OS          : Linux
Release     : 5.4.104+
Machine     : x86_64
Processor   : x86_64
CPU cores   : 2
Architecture: 64bit

torch  : 1.10.0+cu111
IPython: 5.5.0

