# Actor Critic class

References:

* https://pytorch.org/docs/stable/generated/torch.where.html#torch.where
* https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html#torch.nn.Softmax
* https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam
* https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.MultiStepLR.html#torch.optim.lr_scheduler.MultiStepLR
* https://pytorch.org/docs/stable/distributions.html#categorical
* https://pytorch.org/docs/stable/generated/torch.nn.functional.one_hot.html#torch.nn.functional.one_hot
* https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html#torch.repeat_interleave



Style guide (PEP 8): https://peps.python.org/pep-0008/

In [None]:
class CPFMAgent():
  """
  This class represents an agent trading in a Constant Product Market for
  cash settled futures as outlined in the paper:
  "Constant Product Market for Cash Settled Futures" by W. Gawlikowicz,
  B. Mannerings and D. Siska

  Contains all the state updates for an agent trading in this market,
  including price updates, as defined in the above paper. Throughout, the
  pool price is A_B / A_S.
  """

  def __init__(self, utility, trade_size, fee_infrastructure, R_long,
               R_short, l_search, l_initial, l_release, random_agent_update,
               variable_fees, CPFM_kw_args, verbose = False, device = 'cpu',
               **kwargs):
    """
    Initialize the CPFM market, setting all fixed parameters.

    Parameters
    ----------
    utility : partial
      the utility function used to train an agent on this market
      (only relevant for trained agents); called as utility(x_next, profit)
    trade_size : float
      the minimum trade size for a futures trade on this market
    fee_infrastructure : float
      the fixed infrastructure fee for trades on this market
    R_long : float
      the long risk factor for the market, constant in this case
    R_short : float
      the short risk factor for the market, constant in this case
    l_search : float
      search margin scaling level (> 1)
    l_initial : float
      initial margin scaling level (> l_search)
    l_release : float
      release margin scaling level (> l_initial)
    random_agent_update : bool
      whether or not the price update should use random trader agents. If
      False, use the Price Match Update model. If True, require that
      CPFM_kw_args contains keys 'num_traders', an int indicating the
      number of traders to simulate, and 'traders_pay_fees', a bool
      indicating whether or not the trader agents pay fees
    variable_fees : bool
      whether or not fees are variable in the market. If False, require that
      CPFM_kw_args contains key 'fee_liquidity', a float setting the fixed
      liquidity fee
    CPFM_kw_args : dict
      contains keyword arguments necessary under certain conditions (above)
    verbose : bool, optional
      whether or not to print certain debugging messages. Defaults to False
    device : str, optional
      the device on which to run pytorch operations. Defaults to 'cpu'
    """

    self.verbose = verbose
    self.device = device
    # Running cost / utility function
    self.f = utility
    # Set fees in the market
    self.variable_fees = variable_fees
    if not variable_fees:  # Set the fixed liquidity fee
      self.f_L = CPFM_kw_args['fee_liquidity']
    self.f_I = fee_infrastructure
    # Trade size
    self.ts = trade_size
    # Risk parameters
    self.R_long = R_long
    self.R_short = R_short
    self.l_s = l_search
    self.l_i = l_initial
    self.l_r = l_release
    # Set the price update model
    self.random_agent_update = random_agent_update
    if self.random_agent_update:
      # Set necessary parameters for random agent update model
      self.num_traders = CPFM_kw_args['num_traders']
      self.traders_pay_fees = CPFM_kw_args['traders_pay_fees']

  def M_long(self, A_B, A_S):
    """Get maintenance margin for long position"""
    return A_B / A_S * self.R_long

  def M_short(self, A_B, A_S):
    """Get maintenance margin for short position"""
    return A_B / A_S * self.R_short

  def pool_ratio_deltas(self, delta, A_B, A_S):
    """
    Obtain changes (deltaB, deltaS) to A_B and A_S needed to adjust the total
    pool balance by amount delta without changing the ratio A_B / A_S.
    """
    deltaS = delta * A_S / (A_B+A_S)
    deltaB = delta * A_B / (A_B+A_S)
    return deltaB, deltaS

  def maintain_pool_ratio(self, delta, A_B, A_S):
    """
    Update total pool balance A_B + A_S by delta without changing the ratio
    A_B / A_S (and hence without changing the pool price).
    """
    # pool_ratio_deltas returns (deltaB, deltaS) in that order; unpacking them
    # the other way round would move the price whenever A_B != A_S.
    deltaB, deltaS = self.pool_ratio_deltas(delta, A_B, A_S)
    return A_B + deltaB, A_S + deltaS

  def check_pool_price_update(self, A_B, A_S, dn):
    """Ensure that pool price update will not cause distressed pool"""
    P_old = A_B / A_S
    return torch.logical_and(A_B + dn*self.ts*P_old > 0,
                             A_S - dn*self.ts*P_old > 0)

  def pool_price_update(self, A_B, A_S, dn):
    """
    Adjust the pool position (and therefore the price) based on a trade of
    dn * trade size
    """
    P_old = A_B / A_S
    A_S = A_S - dn*self.ts*P_old
    A_B = A_B + dn*self.ts*P_old
    return A_B, A_S

  def margin(self, A_B, A_S, n, g, m, f_L, dn, close_out=False):
    """
    Update agent's positions assuming they meet margin and fee requirements.
    Corresponds to sections 2.1.1 - 2.1.5 in the paper "Constant Product
    Market for Cash Settled Futures".

    Note: A_B and A_S are updated in place (via boolean-mask assignment) as
    well as returned.

    dn : int, float or torch.Tensor
        change in position in terms of fraction of trade size, in [-1, 1].
        A Python scalar is broadcast to every entry of the batch.
    close_out : bool, optional
        whether or not transaction is a close out trade. Defaults to False
    """

    if isinstance(dn, (int, float)):
      dn = torch.ones_like(A_B, device=self.device) * dn
    # Make necessary adjustments for long or short position
    M = torch.where(n + dn >= 0, self.M_long(A_B, A_S), self.M_short(A_B, A_S))
    # Fees
    if close_out:
      # Do not charge fees on close out trades
      f_L_loc = torch.zeros_like(dn, device=self.device)
      F_L = torch.zeros_like(dn, device=self.device)
    else:
      f_L_loc = f_L.clone()
      F_L = torch.abs(dn) * self.ts * A_B / A_S * f_L
    f_I = self.f_I if not close_out else 0.
    F_I = torch.abs(dn) * self.ts * A_B / A_S * f_I
    F = F_L + F_I
    # Margin required for position
    margin_req = torch.abs(n+dn) * self.ts * M * self.l_i
    # Whether or not user is increasing their position
    is_increasing = torch.abs(n+dn) >= torch.abs(n)
    if close_out:
      # Always allow the user to trade if closing out a position
      can_trade = torch.ones_like(A_B, device=self.device).to(torch.bool)
    else:
      # If reducing position, user can trade if fees can be paid
      # from general account. If increasing position, user can trade
      # if fees can be paid from general account and the required
      # margin can be met from general and margin account
      can_trade = (g >= F) & (~is_increasing | (g + m >= F + margin_req))
      # Ensure trade will not make either of the pools <= 0
      can_trade = can_trade & self.check_pool_price_update(A_B, A_S, dn)
    b = torch.where(is_increasing & can_trade & (margin_req > m),
                    margin_req - m, 0.)
    n = torch.where(can_trade, n + dn, n)
    # Only charge fee if a trade occurred, and fee not already 0. Mask the
    # liquidity part directly instead of splitting F by f_L / (f_L + f_I),
    # which would evaluate 0 / 0 when both fee rates are zero.
    charge_fee = can_trade & ((f_L_loc + f_I) > 0.)
    F = torch.where(charge_fee, F, 0.)
    F_L = torch.where(charge_fee, F_L, 0.)
    g = g - F - b
    m = m + b
    # Adjust the pool position based on the trade
    A_B[can_trade], A_S[can_trade] =\
        self.pool_price_update(A_B[can_trade], A_S[can_trade], dn[can_trade])
    # Add liquidity fees to pools
    A_B, A_S = self.maintain_pool_ratio(F_L, A_B, A_S)
    return A_B, A_S, n, g, m

  def mark_to_market(self, P_old, A_B, A_S, n, g, m):
    """
    Mark-to-market according to 2.1.6 in the paper "Constant Product Market
    for Cash Settled Futures"
    """

    deltaP = (A_B / A_S) - P_old
    # A loss is capped at the balance of the margin account m
    delta = torch.where(n * self.ts * deltaP + m > 0,
                        n * self.ts * deltaP, -m)
    deltaB, deltaS = self.pool_ratio_deltas(-delta, A_B, A_S)
    # Check if pool distressed, if so do not mark to market.
    # Withdraw margin to general account and set position to 0
    distressed = (-deltaB > A_B) | (-deltaS > A_S)
    if self.verbose and torch.any(distressed):
      print('Distressed pool')
    A_B = torch.where(distressed, A_B, A_B + deltaB)
    A_S = torch.where(distressed, A_S, A_S + deltaS)
    n = torch.where(distressed, 0., n)
    g = torch.where(distressed, g + m, g)
    m = torch.where(distressed, 0., m + delta)
    return A_B, A_S, n, g, m

  def margin_checks(self, A_B, A_S, n, g, m, f_L):
    """
    Perform margin checks according to 2.1.7 in the paper "Constant Product
    Market for Cash Settled Futures"

    Note: entries of A_B, A_S, n, g, m belonging to distressed agents are
    updated in place.
    """

    M = torch.where(n >= 0, self.M_long(A_B, A_S), self.M_short(A_B, A_S))
    # Top up margin from general account if needed
    b = torch.maximum(torch.minimum(torch.abs(n) * self.ts * M * self.l_s - m,
                                    g), torch.zeros_like(n))
    m = m + b
    g = g - b
    # Check if agent distressed: Steps 3 - 4
    is_dis = ~((n == 0.) | (m > torch.abs(n) * self.ts * M))
    if torch.any(is_dis):
      # Submit close out trades until agent not distressed: Step 5
      dn = torch.where(torch.abs(n[is_dis]) > 1.,
                       torch.where(n[is_dis] > 0., -1., 1.).to(torch.float64),
                       -n[is_dis])
      n_old = n[is_dis]
      # Note: Causes recursion (make_trade calls margin_checks again)
      A_B[is_dis], A_S[is_dis], n[is_dis], g[is_dis], m[is_dis] =\
        self.make_trade(A_B[is_dis], A_S[is_dis], n[is_dis], g[is_dis],
                        m[is_dis], f_L[is_dis], dn, True)
      # Note: I assume penalty based on old value of n since it is never 0
      penalty = torch.minimum(m[is_dis], 1 / torch.abs(n_old) * m[is_dis])
      m[is_dis] = m[is_dis] - penalty
      A_B[is_dis], A_S[is_dis] = self.maintain_pool_ratio(penalty, A_B[is_dis],
                                                          A_S[is_dis])
    # Withdraw excess from margin account: Steps 6 - 7
    b = torch.maximum(m - torch.abs(n)*self.ts*M*self.l_r, torch.zeros_like(n))
    m = m - b
    g = g + b
    return A_B, A_S, n, g, m

  def make_trade(self, A_B, A_S, n, g, m, f_L, dn, close_out=False):
    """
    Make a long or short trade of size equal to dn * self.ts with necessary
    updates
    """
    A_B, A_S, n, g, m = self.margin(A_B, A_S, n, g, m, f_L, dn, close_out)
    # No need to mark to market with only one trader
    A_B, A_S, n, g, m = self.margin_checks(A_B, A_S, n, g, m, f_L)
    return A_B, A_S, n, g, m

  def fee_update(self, A_B, A_S, n, g, m, f_L):
    """Placeholder for state update for fees; returns f_L unchanged"""
    return f_L

  def trade_step(self, A_B, A_S, n, g, m, f_L, dn, s_ref):
    """
    Trade |dn| units of size self.ts (one unit at a time, in the direction
    of sign(dn)), then let the rest of the market react.

    dn : int
      Integer number of units of size self.ts to trade
    s_ref : torch.Tensor
      the price on the reference market that agents have knowledge of

    Returns (A_B, A_S, n, g, m, f_L, profit), where profit is the change in
    g + m over the step.
    """

    # Save so that profit can be calculated
    g_old = g.clone()
    m_old = m.clone()
    # Adjust the position by trading the necessary number of single units;
    # each unit is a separate trade so margin checks run after every unit.
    unit = 1 if dn > 0 else -1
    for _ in range(abs(dn)):
      A_B, A_S, n, g, m = self.make_trade(A_B, A_S, n, g, m, f_L, unit)
    # Simulate other agents to adjust market
    if not self.random_agent_update:
      A_B, A_S, n, g, m =\
        self.oracle_settlement_price_update(s_ref, A_B, A_S, n, g, m, f_L)
    else:
      for _ in range(self.num_traders):
        A_B, A_S, n, g, m = self.simulate_trader(s_ref, A_B, A_S, n, g, m, f_L)
    profit = (g+m) - (g_old+m_old)
    f_L = self.fee_update(A_B, A_S, n, g, m, f_L)
    return A_B, A_S, n, g, m, f_L, profit

  def oracle_settlement_price_update(self, P, A_B, A_S, n, g, m, f_L):
    '''
    Update pools A_B, A_S to match price P according to 2.2 in the paper
    "Constant Product Market for Cash Settled Futures"
    '''

    P_old = A_B / A_S
    # Shift deltaP from A_B to A_S so that A_B / A_S == P and A_B + A_S is kept
    deltaP = (A_B-P*A_S) / (1+P)
    A_B = A_B - deltaP
    A_S = A_S + deltaP
    A_B, A_S, n, g, m = self.mark_to_market(P_old, A_B, A_S, n, g, m)
    A_B, A_S, n, g, m = self.margin_checks(A_B, A_S, n, g, m, f_L)
    return A_B, A_S, n, g, m

  def simulated_trade(self, A_B, A_S, dn, f_L):
    """Simulate a trade of dn by a trader agent"""

    # Ensure simulated trade will not cause pool to become distressed
    can_trade = self.check_pool_price_update(A_B, A_S, dn)
    # Adjust the pool position based on the trade (in place)
    A_B[can_trade], A_S[can_trade] =\
      self.pool_price_update(A_B[can_trade], A_S[can_trade], dn[can_trade])
    # Take fee from simulated trader, priced at the post-trade pool price
    if self.traders_pay_fees:
      F_L = torch.abs(dn) * self.ts * A_B / A_S * f_L
      A_B[can_trade], A_S[can_trade] =\
        self.maintain_pool_ratio(F_L[can_trade], A_B[can_trade],
                                 A_S[can_trade])
    return A_B, A_S

  def simulate_trader(self, s, A_B, A_S, n, g, m, f_L):
    """
    Simulate a trader on the market according to which market offers
    the best price, trading with 50% probability
    """

    P_old = A_B / A_S
    # Determine what trade to make
    if self.traders_pay_fees:
      zeros = torch.zeros_like(s, dtype=torch.float64, device=self.device)
      ones = torch.ones_like(s, dtype=torch.float64, device=self.device)
      dn = torch.where(
        P_old + self.ts*f_L*P_old < s, ones,
        torch.where(P_old - self.ts*f_L*P_old > s, -ones, zeros)) \
          * torch.where(torch.randn(A_B.size(0), device=self.device) < 0., 0.,
                        1.)
    else:
      dn = torch.where(P_old < s, 1., -1.) \
        * torch.where(torch.randn(A_B.size(0), device=self.device) < 0., 0.,
                      1.)
    # Execute trade according to the usual steps, marking to market for
    # the trained agent
    A_B, A_S = self.simulated_trade(A_B, A_S, dn, f_L)
    A_B, A_S, n, g, m = self.mark_to_market(P_old, A_B, A_S, n, g, m)
    A_B, A_S, n, g, m = self.margin_checks(A_B, A_S, n, g, m, f_L)
    return A_B, A_S, n, g, m

  def C(self, x):
    """ As the unimplemented policy the agent always does nothing """

    c = torch.zeros((x.size(0),3), device=self.device)
    c[:,2] = torch.ones_like(c[:,2])
    return c

  def step(self, x: torch.Tensor, train = False, s_next = None, **kwargs):
    """
    Do a state update step in the CPFM and reference market based on the policy
    of the agent

    Parameters
    ----------
    x: torch.Tensor
      State; columns (s, pool price, pool sum, n, g, m) plus f_L when fees
      are variable
    train: bool, optional
      if False, sample one action per row from the policy; if True, use the
      policy weights to form the expected next state. Defaults to False
    s_next: torch.Tensor
      manually provide the next state of the reference market. This is useful
      when evaluating policies against each other to provide each agent with
      the same market conditions
    kwargs: dict
      arguments necessary to calculate a step of mid-price

    Returns (c, x_next, running_cost, profit).
    """

    c = self.C(x)
    # Extract the state variables
    s, pool_price, pool_sum, n, g, m =\
      x[:,0], x[:,1], x[:,2], x[:,3], x[:,4], x[:,5]
    # Extract the fees variable in variable fees case
    if self.variable_fees:
      f_L = x[:,6]
    else:
      f_L = self.f_L * torch.ones_like(s, device=self.device)
    # Derive pool values for calculations
    A_S = pool_sum / (pool_price + 1)
    A_B = pool_price * A_S
    # If not training, select an action from the policy distribution otherwise
    # use the probabilities to get an expected next state
    if not train:
      q = Categorical(c)
      c = one_hot(q.sample(), num_classes = 3)
    # Either update the reference market price by a GBM or take the value given
    if s_next is None:
      s_step = partial(GBM_step, **kwargs)
      s_next = s_step(s=s)
    # Long position on futures
    A_B_long, A_S_long, n_long, g_long, m_long, f_long, profit_long = \
      self.trade_step(torch.clone(A_B), torch.clone(A_S), torch.clone(n),
                      torch.clone(g), torch.clone(m), torch.clone(f_L), 1, s)
    # Short position on futures
    A_B_short, A_S_short, n_short, g_short, m_short, f_short, profit_short = \
      self.trade_step(torch.clone(A_B), torch.clone(A_S), torch.clone(n),
                      torch.clone(g), torch.clone(m), torch.clone(f_L), -1, s)
    # Do nothing
    A_B_not, A_S_not, n_not, g_not, m_not, f_not, profit_not = \
      self.trade_step(torch.clone(A_B), torch.clone(A_S), torch.clone(n),
                      torch.clone(g), torch.clone(m), torch.clone(f_L), 0, s)
    # Get the expected value of each state variable
    n = c[:,0]*n_long + c[:,1]*n_short + c[:,2]*n_not
    g = c[:,0]*g_long + c[:,1]*g_short + c[:,2]*g_not
    m = c[:,0]*m_long + c[:,1]*m_short + c[:,2]*m_not
    f_L = c[:,0]*f_long + c[:,1]*f_short + c[:,2]*f_not
    # Ensure that the price and the total pool balance are equal to their
    # expectations, rather than the specific balances of A_B and A_S
    price = c[:,0]*A_B_long/A_S_long + c[:,1]*A_B_short/A_S_short \
      + c[:,2]*A_B_not/A_S_not
    pool_balance = c[:,0]*(A_B_long+A_S_long) + c[:,1]*(A_B_short+A_S_short) \
      + c[:,2]*(A_B_not+A_S_not)
    # Calculate expected profit
    profit = c[:,0]*profit_long + c[:,1]*profit_short + c[:,2]*profit_not
    # Recombine the state variables into one tensor
    if self.variable_fees:
      x_next = torch.stack((s_next, price, pool_balance, n, g, m, f_L), dim=1)
    else:
      x_next = torch.stack((s_next, price, pool_balance, n, g, m), dim=1)
    # Calculate the running reward / utility
    running_cost = self.f(x_next, profit).reshape(-1,1)
    return c, x_next, running_cost, profit

In [None]:
class RandomAgent(CPFMAgent):
  """A CPFMAgent that selects a random action at each time-step"""

  def C(self, x):
    """
    Return random action weights: a row-wise softmax over i.i.d. standard
    normal scores, giving one row of 3 weights per state in the batch.
    """
    scores = torch.randn((x.size(0), 3), device=self.device)
    return torch.softmax(scores, dim=1)

In [None]:
class IdealOneStepAgent(CPFMAgent):
  """
  A CPFMAgent that will choose the ideal policy for a market following the
  exact price match update model
  """

  def C(self, x):
    """
    Return a one-hot policy (long, short, do nothing) choosing the action with
    the highest one-step profit.

    The profit model mirrors the base-class dynamics: a unit trade moves the
    pool as in pool_price_update, pays fees ts * P * (f_I + f_L) as in
    margin, and the resulting position is then marked to market against the
    reference price s (the price match update). Ties resolve to "do nothing".
    """
    # Extract the state variables
    s, price, pool_sum, n = x[:,0], x[:,1], x[:,2], x[:,3]
    # Extract fees in variable fees case
    if self.variable_fees:
      f_L = x[:,6]
    else:
      f_L = self.f_L
    A_S = pool_sum / (1 + price)
    A_B = price * A_S
    P = A_B / A_S
    ts = self.ts
    # Fees for a single unit trade (liquidity + infrastructure)
    fee = ts * P * (self.f_I + f_L)
    # Pool price right after a +1 / -1 unit trade
    P_long = (A_B + ts * P) / (A_S - ts * P)
    P_short = (A_B - ts * P) / (A_S + ts * P)
    # Calculate the one-step profit from each action
    profit_long = (n + 1) * ts * (s - P_long) - fee
    profit_short = (n - 1) * ts * (s - P_short) - fee
    profit_not = n * ts * (s - P)
    c = torch.zeros((x.size(0),3), device=self.device)
    # Set the action with the highest one step profit
    c[:, 0] = torch.where((profit_long > profit_short)
                          & (profit_long > profit_not), 1., 0.)
    c[:, 1] = torch.where((profit_short > profit_long)
                          & (profit_short > profit_not), 1., 0.)
    c[:, 2] = torch.where((c[:, 0] == 0.) & (c[:, 1] == 0.), 1., 0.)
    return c

In [None]:
class IdealOneStepAgentTraders(CPFMAgent):
  """
  A CPFMAgent that will choose the ideal policy for a market following the
  simulated random trader price update model
  """

  def Pr(self, s, A_B, A_S, n, dn, k, f_L):
    """
    Calculate the agent's mark-to-market profit if exactly k random traders
    trade after the agent trades dn units.

    Mirrors the base-class dynamics:
    - The agent's trade moves the pool as in pool_price_update.
    - The agent's liquidity fee |dn| * ts * P * f_L is added to the pool, as in
      margin. No fee is added when dn == 0.
    - Each trader trades towards s when the price gap exceeds the fee band
      ts * f_L * P, as in simulate_trader.
    - The trader pays a fee at the post-trade price, as in simulated_trade.

    Trading fees paid by the agent are NOT subtracted here.
    """
    ts = self.ts
    P = A_B / A_S
    # Agent's trade followed by its liquidity fee (zero when dn == 0)
    F_L = abs(dn) * ts * P * f_L
    A_B_prime = A_B + ts * dn * P
    A_S_prime = A_S - ts * dn * P
    A_B_prime, A_S_prime = self.maintain_pool_ratio(F_L, A_B_prime, A_S_prime)
    P_prime = A_B_prime / A_S_prime
    expectation = torch.zeros_like(s)
    zeros = torch.zeros_like(s, dtype=torch.float64)
    ones = torch.ones_like(s, dtype=torch.float64)
    for _ in range(k):
      # Price band a fee-paying trader must beat before trading
      if self.traders_pay_fees:
        band = ts * f_L * P_prime
      else:
        band = torch.zeros_like(P_prime)
      dn_prime = torch.where(P_prime + band < s, ones,
                             torch.where(P_prime - band > s, -ones, zeros))
      A_B_new = A_B_prime + ts * dn_prime * P_prime
      A_S_new = A_S_prime - ts * dn_prime * P_prime
      if self.traders_pay_fees:
        # Trader's fee is priced at the post-trade pool price
        F_trader = torch.abs(dn_prime) * ts * (A_B_new / A_S_new) * f_L
        A_B_new, A_S_new = self.maintain_pool_ratio(F_trader, A_B_new, A_S_new)
      # Mark the agent's position to market against the price move
      expectation += (n + dn) * ts * (A_B_new / A_S_new - P_prime)
      A_B_prime, A_S_prime = A_B_new, A_S_new
      P_prime = A_B_prime / A_S_prime
    return expectation

  def C(self, x):
    """
    Return a one-hot policy (long, short, do nothing) choosing the action with
    the highest expected one-step profit, where each of self.num_traders
    traders trades independently with probability 1/2 (as in
    simulate_trader). Ties resolve to "do nothing".
    """
    s, price, pool_sum, n = x[:,0], x[:,1], x[:,2], x[:,3]
    if self.variable_fees:
      f_L = x[:,6]
    else:
      f_L = self.f_L * torch.ones_like(x[:,0])
    A_S = pool_sum / (1 + price)
    A_B = price * A_S
    c = torch.zeros((x.size(0),3), device =self.device)
    expected_long_profit = torch.zeros_like(s, dtype=torch.float64)
    expected_short_profit = torch.zeros_like(s, dtype=torch.float64)
    expected_not_profit = torch.zeros_like(s, dtype=torch.float64)
    # Assume self.num_traders set (random_agent_update=True)
    N = self.num_traders
    for k in range(N + 1):
      # Probability that exactly k of the N traders trade
      weight = 0.5**N * scipy.special.binom(N, k)
      expected_long_profit += weight * self.Pr(s, A_B, A_S, n, 1., k, f_L)
      expected_short_profit += weight * self.Pr(s, A_B, A_S, n, -1., k, f_L)
      expected_not_profit += weight * self.Pr(s, A_B, A_S, n, 0., k, f_L)

    # Fees for a single unit trade (liquidity + infrastructure), as in margin
    fee = self.ts * A_B / A_S * (self.f_I + f_L)
    expected_long_profit -= fee
    expected_short_profit -= fee

    c[:, 0] = torch.where((expected_long_profit > expected_short_profit)
                          & (expected_long_profit > expected_not_profit),
                          1., 0.)
    c[:, 1] = torch.where((expected_short_profit > expected_long_profit)
                          & (expected_short_profit > expected_not_profit),
                          1., 0.)
    c[:, 2] = torch.where((c[:, 0] == 0.) & (c[:, 1] == 0.), 1., 0.)
    return c

In [None]:
def get_activation_from_str(activation_str):
  """
  Map the string name of a supported activation module (e.g. 'nn.ReLU') to
  the corresponding torch.nn class (the class itself, not an instance).

  Raises ValueError for any name that is not supported.
  """
  supported = (
      ('nn.Identity', nn.Identity),
      ('nn.Softmax', nn.Softmax),
      ('nn.ReLU', nn.ReLU),
      ('nn.Sigmoid', nn.Sigmoid),
  )
  for name, activation_cls in supported:
    if activation_str == name:
      return activation_cls
  raise ValueError('Unknown activation string')

Background on PyTorch activation functions: https://towardsdatascience.com/understanding-pytorch-activation-functions-the-maths-and-algorithms-part-2-1f8bce111a7b

In [None]:
import math
from torch.autograd.grad_mode import F
class ActorCritic(CPFMAgent):
  """
  A CPFMAgent trained via an actor critic algorithm to approximate the
  value function and optimal policy function.

  self.C (policy network) replaces CPFMAgent.C and maps a state to 3 action
  weights (long, short, do nothing); self.v (value network) maps a state to
  a single value.
  """

  def __init__(self, utility : partial, trade_size : float,
               fee_infrastructure: float, R_long: float, R_short: float,
               l_search: float, l_initial: float, l_release: float,
               random_agent_update : bool, variable_fees : bool,
               CPFM_kw_args : dict, x0_sampler : partial,
               discount_factor: float, dims: int, verbose = False,
               device = 'cpu', gamma = 0.1, lr = 0.005, **kwargs):
    """
    Parameters
    ----------
    utility, trade_size, fee_infrastructure, R_long, R_short, l_search,
    l_initial, l_release, random_agent_update, variable_fees, CPFM_kw_args,
    verbose, device :
      market parameters, passed unchanged to CPFMAgent.__init__
    x0_sampler : partial
      called as x0_sampler(n_batch) to draw a batch of initial states
    discount_factor : float
      discount factor in the Bellman equation, in (0, 1)
    dims : int
      dimension of the state vector; input size of both networks
    gamma : float, optional
      multiplicative learning-rate decay used by both schedulers.
      Defaults to 0.1
    lr : float, optional
      initial Adam learning rate for both networks. Defaults to 0.005
    kwargs : dict
      also passed to CPFMAgent.__init__. For each network suffix X in
      ('C', 'v') it must contain 'hidden_dims_X', 'activation_X',
      'output_activation_X', 'normalize_input_X', 'normalize_hidden_X',
      'normalize_output_X' and 'initial_weight_var_X'. 'milestones' is
      required unless kwargs['scheduler'] == 'Exponential'.
    """

    super().__init__(utility, trade_size, fee_infrastructure,
                     R_long, R_short, l_search, l_initial, l_release,
                     random_agent_update, variable_fees, CPFM_kw_args,
                     verbose = verbose, device = device, **kwargs)

    self.x0_sampler = x0_sampler
    self.discount_factor = discount_factor
    # Dimension of the state (s, pool price, pool sum, n, g, m[, f_L])
    self.d = dims

    # Policy network: state -> 3 action weights
    self.C = self._build_network('C', 3, kwargs)
    self.optimizer_C = torch.optim.Adam(self.C.parameters(), lr=lr)
    self.scheduler_C = self._make_scheduler(self.optimizer_C, gamma, kwargs)

    # Value function network: state -> value
    self.v = self._build_network('v', 1, kwargs)
    self.optimizer_v = torch.optim.Adam(self.v.parameters(), lr=lr)
    self.scheduler_v = self._make_scheduler(self.optimizer_v, gamma, kwargs)

  def _build_network(self, suffix, out_dim, config):
    """Build and initialise an FFN from the config keys ending in _<suffix>."""
    net = FFN(sizes = [self.d] + config['hidden_dims_' + suffix] + [out_dim],
              activation=get_activation_from_str(config['activation_' + suffix]),
              output_activation=get_activation_from_str(
                config['output_activation_' + suffix]),
              normalize_input=config['normalize_input_' + suffix],
              normalize_hidden=config['normalize_hidden_' + suffix],
              normalize_output=config['normalize_output_' + suffix]
              ).to(self.device)
    net.apply(partial(init_weights, var = config['initial_weight_var_' + suffix]))
    return net

  def _make_scheduler(self, optimizer, gamma, config):
    """Build the learning-rate scheduler selected by config['scheduler']."""
    if config.get('scheduler') == 'Exponential':
      return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)
    return torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                milestones=config['milestones'],
                                                gamma=gamma)

  def step(self, x, n_mc, train = True, s_next = None, **kwargs):
    """
    Override the default step function to have training be True by default
    and to take the parameter n_mc for Monte-Carlo sampling: each row of x
    is repeated n_mc times (consecutively) before stepping.

    NOTE(review): a supplied s_next presumably has to match the repeated
    batch size x.shape[0] * n_mc — confirm against callers.
    """
    x_mc = torch.repeat_interleave(x, n_mc, dim=0)
    return super().step(x_mc, train = train, s_next = s_next, **kwargs)

  def _dynamic_programming(self, x: torch.Tensor, n_mc: int, **kwargs):
    """
    Performs one environment step and returns the Bellman loss

    Parameters
    ----------
    x: torch.Tensor
        tensor of shape (N_batch, self.d) of states
    n_mc: int
        Number of monte carlo samples to approximate drift
    kwargs: dict
        arguments necessary to calculate a step of mid-price

    Returns
    -------
    bellman_loss: torch.Tensor
        bellman loss: mean over batch of
        ( v(x) - 1/N_mc \\sum(f + delta * v(x_next)) )^2
    bellman_approx: torch.Tensor
        batch mean of the bellman approximation of v(x):
        1/N_mc \\sum(f + delta * v(x_next))
    """
    n_batch = x.shape[0]
    _, x_next, running_cost, _ = self.step(x, n_mc, train=True, **kwargs)
    # bellman loss (equation (5) in paper)
    bellman_approx = running_cost + self.discount_factor * self.v(x_next)
    # Average over Monte-Carlo Samples (equation (5)); rows for the same
    # initial state are consecutive because of repeat_interleave in step()
    bellman_approx = bellman_approx.reshape(n_batch, n_mc, -1).mean(1)
    # Average over batches and take power of 2 (equation (5)); the target is
    # detached so only v(x) is regressed
    bellman_loss = torch.pow(self.v(x) - bellman_approx.detach(), 2).mean()
    return bellman_loss, bellman_approx.mean()

  def update_alpha(self, n_batch, n_mc, **kwargs):
    """
    Gradient ascent on the policy network to maximise the bellman approx

    Parameters
    ----------
    n_batch: int
        batch size
    n_mc: int
        Monte Carlo size for Monte Carlo approximation of running cost to
        have some exploration
    kwargs: dict
        arguments necessary to calculate a step of mid-price

    Returns the (detached) batch-mean bellman approximation.
    """
    toggle(self.v, to=False)
    toggle(self.C, to=True)
    self.C.train()
    x0 = self.x0_sampler(n_batch)
    self.optimizer_C.zero_grad()
    _, bellman_approx = self._dynamic_programming(x0, n_mc, **kwargs)
    # The optimizer minimises, so minimise the negated objective
    (-bellman_approx).backward()
    self.optimizer_C.step()
    self.scheduler_C.step()
    return bellman_approx.detach()

  def update_v(self, n_batch, n_mc, **kwargs):
    """
    Gradient descent on the value network to minimise the bellman loss

    Parameters
    ----------
    n_batch: int
        batch size
    n_mc: int
        Monte Carlo size for Monte Carlo approximation of running cost to
        have some exploration
    kwargs: dict
        arguments necessary to calculate a step of mid-price

    Returns the (detached) bellman loss.
    """
    toggle(self.v, to=True)
    toggle(self.C, to=False)
    self.v.train()
    x0 = self.x0_sampler(n_batch)
    self.optimizer_v.zero_grad()
    bellman_loss, _ = self._dynamic_programming(x0, n_mc, **kwargs)
    bellman_loss.backward()
    self.optimizer_v.step()
    self.scheduler_v.step()
    return bellman_loss.detach()