# Class Inheritance
:label:`inheritance`

As we have seen, classes allow you to define new types that can behave and feel
like Python built-in types or types from the `numpy` module. 

They allow you to group attributes together with the methods that operate on them, so you can
encapsulate behaviour that is specific to your types. They also support operator overloading
through special methods such as `__call__`, `__str__`, and `__mul__`, which define what calling
an object, converting it to a string, or multiplying it with `*` does for your own types.

Classes define types. You instantiate or construct multiple object instances from the blueprint
provided by your class. For example, given a class `CallPayoff` you can instantiate multiple objects with different strikes.

Through these features, classes let you reuse code even further than functions do, when applicable.

In this section, we will introduce another concept supported by classes: inheritance. Inheritance is a language feature that provides a way to define hierarchies of types with subclasses inheriting from superclasses, to customize and specialize the types.

## Defining Inheritance

Inheritance can help us to represent objects which have some differences and some 
similarities in the way they work. We can put all the functionality that the objects 
have in common in a base class, and then define one or more subclasses with their 
own custom functionality.

Subclasses lower in the inheritance hierarchy specialize behavior of classes higher 
up in the hierarchy by overriding the more general definitions of attributes/methods 
higher in the tree. The more general types are high up in the hierarchy, while more 
specialized classes are found lower down the hierarchy. Classes inherit all of the 
attributes from their ancestors and can override some.
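
As a minimal sketch of these ideas, the following uses two hypothetical classes, `Instrument` and `Bond` (illustrative names that are not part of this section's pricing code), to show a subclass inheriting one method unchanged and overriding another:

```python
class Instrument:
    """Base class holding what all instruments have in common (hypothetical example)."""

    def __init__(self, name):
        self.name = name

    def kind(self):
        # general definition, meant to be overridden lower in the hierarchy
        return "generic instrument"

    def describe(self):
        # inherited unchanged by every subclass
        return "{} is a {}".format(self.name, self.kind())


class Bond(Instrument):
    """Subclass that specializes only the part that differs."""

    def kind(self):
        # overrides Instrument.kind
        return "bond"


print(Instrument("X").describe())
print(Bond("Y").describe())
```

`Bond` defines neither `__init__` nor `describe`, yet both work on its instances because they are inherited from `Instrument`.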

## Motivating application

For instance, we are going to define a new Monte-Carlo pricing function
that can price path-dependent options for some yet undefined option types. 
For now, we will not use vectorized operations and just do a simple loop 
over the paths.

Because a path-dependent derivative's payoff depends on the value of the 
underlying at multiple times, the function needs to:

- query the option object for the times that are needed
- update the stochastic process repeatedly through the times
- hand over the different values of the underlying at the different times

### A Monte-Carlo pricing routine for path-dependent options

Here is one possible implementation. Although not the best one, this serves 
to illustrate the topic:

In [1]:
def mc_path_dependent(option_state, process, yield_curve, n_paths):
    """Computes the fair value of a path-dependent option using Monte-Carlo simulations.
    
    Parameters
    ----------
    option_state : OptionState
        The option payoff state variable recorder
    process : object with method update()
        The underlying's stochastic process evolver
    yield_curve : object with method discount()
        The yield curve for discounting future cash flows
    n_paths : int
        The number of paths to simulate
    """

    # Helper function to process one path
    def do_one_path(times):
        # move over each time node
        for t in times: 
            # generate one standard normal variable
            std_norm = np.random.normal() 
            # simulate the value of the underlying at the next time
            St = process.update(t, std_norm) 
            # pass the value of the underlying at time t to the option_state
            # so it can use it when computing the payoff
            option_state.update(t, St) 
        
        return option_state.calculate_payoff()


    # query for the times needed in the simulation
    times = option_state.times 
    # initialize the running sum to zero
    running_sum = 0 
    for i in range(n_paths):
        # reset process to start a new simulation
        process.reset()        
        # reset option state to start a new simulation
        option_state.reset()
        # simulate one path
        path_value = do_one_path(times)
        # update the running sum with the payoff at the end of the path
        running_sum += path_value
    
    return running_sum / n_paths * yield_curve.discount(option_state.expiry)

### Auxiliary payoff and yield curve classes

We can reuse our `Call` and `FixedRateYieldCurve` classes:

In [2]:
import numpy as np

class Call:
    
    def __init__(self, strike):
        self.strike = strike

    def __call__(self, spot):
        return np.maximum(spot - self.strike, 0.)
        
class FixedRateYieldCurve:
    def __init__(self, rate):
        self.rate = rate
        
    def discount(self, maturity):
        return np.exp(-self.rate * maturity)
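
As a quick reminder of how these two helpers behave, here is a small usage sketch. The classes are restated so that the block runs on its own, and the strike, spots, rate, and maturity are arbitrary values chosen for illustration:

```python
import numpy as np


class Call:
    def __init__(self, strike):
        self.strike = strike

    def __call__(self, spot):
        return np.maximum(spot - self.strike, 0.)


class FixedRateYieldCurve:
    def __init__(self, rate):
        self.rate = rate

    def discount(self, maturity):
        return np.exp(-self.rate * maturity)


call = Call(100)
# the object is called like a function thanks to __call__;
# the payoffs are 0, 0 and 20 for spots of 90, 100 and 120
payoffs = call(np.array([90., 100., 120.]))
print(payoffs)

curve = FixedRateYieldCurve(0.02)
# discount factor exp(-0.02 * 5), roughly 0.9048
df = curve.discount(5)
print(df)
```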

### Stochastic Process class

Let's now think about how to implement a process class before looking into 
the more complex class for the option state.

The Monte-Carlo pricing function requires the process object to have:

- a `reset()` method to reset the process back to the start of the simulation 
  every time you do a new path
- an `update()` method to evolve the underlying to the next time in the simulation

#### A buggy implementation

Here is an implementation that appears to satisfy the requirements:

In [3]:
class BlackScholesProcess:
    
    def __init__(self, spot, rate, vol):  
        self.spot = spot
        self.rate = rate
        self.vol = vol   
        self.updated_spot = spot        
        self.updated_time = 0

    def __repr__(self):
        return "{}({!r}, {!r}, {!r}, {!r}, {!r})".format(
            self.__class__.__name__,
            self.spot, self.rate, self.vol, 
            self.updated_spot, self.updated_time
        )

    def __str__(self):
        return "{}(spot={}, rate={}, vol={}, updated_spot={}, updated_time={})".format(
            self.__class__.__name__,
            self.spot, self.rate, self.vol, 
            self.updated_spot, self.updated_time
        )
    
    def reset(self):
        self.updated_spot = self.spot        

    def update(self, time, norms):        
        self.updated_spot = self.spot * np.exp((self.rate - 0.5 * self.vol**2) * time + self.vol*np.sqrt(time)*norms)        
        return self.updated_spot

We create a simple function that runs some basic tests, so that we can reuse it later on:

In [4]:
def test_black_scholes():
    p = BlackScholesProcess(100, 0.02, 0.2)
    print('p0:', p)
    st = p.update(1, 1.0)
    print('p1:', p)
    print('s1:', st)
    st = p.update(2, 0.0)
    print('p2:', p)
    print('s2:', st)

    p.reset()
    print('reset p:', p)

Let's try it:

In [5]:
test_black_scholes()

p0: BlackScholesProcess(spot=100, rate=0.02, vol=0.2, updated_spot=100, updated_time=0)
p1: BlackScholesProcess(spot=100, rate=0.02, vol=0.2, updated_spot=122.14027581601698, updated_time=0)
s1: 122.14027581601698
p2: BlackScholesProcess(spot=100, rate=0.02, vol=0.2, updated_spot=100.0, updated_time=0)
s2: 100.0
reset p: BlackScholesProcess(spot=100, rate=0.02, vol=0.2, updated_spot=100, updated_time=0)


It is running, but do you see any issues?

In fact, there are multiple bugs:

- On `print('p1:', p)` and `print('p2:', p)`, `updated_time` is still `0` even though we updated to times `1` and `2`.
The cause is that `update` never sets `updated_time`. To fix this and avoid this type of error, we will encapsulate both mutable attributes in a single class and only ever set them through its constructor, so that they are always updated together. The new class can be an inner class defined inside `BlackScholesProcess`.


- On `print('s2:', st)`, the simulated value at time `2` is `100` even though it was `122.14...` at time `1` and we passed `0` for the Brownian motion increment. With a zero increment only the drift acts between time `1` and time `2`, and here the drift is zero because `rate - 0.5 * vol**2 = 0`, so we would expect the value to stay at `122.14...`, as the steps below confirm:

In [6]:
t1 = 1
t2 = 2
rate = 0.02
vol = 0.2
St1 = 122.14027581601698

## Manual calculations
ito = rate - 0.5 * vol * vol 
dt = t2 - t1
correct = St1 * np.exp(ito*dt)

## Moving from (t=1,St=St1) to (t=2,St=St1*Return) is the same as
## moving from (t=0, St=St1) to (t=1,St=St1*Return)
## because what matters is value of dt and Return
p = BlackScholesProcess(St1, rate, vol)
forward_simul = p.update(t1, 0)
print('correct:', correct)
print('forward:', forward_simul)

correct: 122.14027581601698
forward: 122.14027581601698


#### Fixing the implementation bugs

The issue is that `update` does not simulate the increment from $t_i$ to $t_{i+1}$; instead it simulates
from the initial spot at time $0$ to time `time`. We should first compute the time increment with
`dt = time - self.updated_time`, then apply it to the current value with
`self.updated_spot *= np.exp((self.rate - 0.5 * self.vol**2) * dt + self.vol*np.sqrt(dt)*norms)`,
and finally record the new time with `self.updated_time = time`.
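
In formula form, and using symbols chosen here for illustration ($Z$ for the standard normal draw `norms`, $r$ for `rate`, $\sigma$ for `vol`, and $\Delta t = t_{i+1} - t_i$ for `dt`), the incremental update reads:

$$
S_{t_{i+1}} = S_{t_i} \exp\left(\left(r - \tfrac{1}{2}\sigma^2\right)\Delta t + \sigma\sqrt{\Delta t}\, Z\right)
$$

The buggy version instead computed $S_0 \exp\left(\left(r - \tfrac{1}{2}\sigma^2\right) t_{i+1} + \sigma\sqrt{t_{i+1}}\, Z\right)$, which restarts from the initial spot $S_0$ at every call.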

In [7]:
class BlackScholesProcess:

    class StateData:
        
        def __init__(self, spot, time):
            self.spot = spot
            self.time = time
        
        def __repr__(self):
            return "{}({!r}, {!r})".format(
            self.__class__.__name__,
            self.spot, self.time
        )
   
    def __init__(self, spot, rate, vol):  
        self.spot = spot
        self.rate = rate
        self.vol = vol
        # a simulation path starts at time 0 from value spot
        self.state = self.StateData(spot, 0)
        
    def __repr__(self):
        return "{}({!r}, {!r}, {!r}, {!r})".format(
            self.__class__.__name__,
            self.spot, self.rate, self.vol, self.state
        )
    
    def reset(self):
        self.state = self.StateData(self.spot, 0)

    def update(self, time, norms):        
        # calculate time increment
        dt = time - self.state.time
        # simulate spot incrementally
        st = self.state.spot * np.exp((self.rate - 0.5 * self.vol**2) * dt + self.vol*np.sqrt(dt)*norms)
        # update state
        self.state = self.StateData(st, time)

        return st

Let's try it:

In [8]:
test_black_scholes()

p0: BlackScholesProcess(100, 0.02, 0.2, StateData(100, 0))
p1: BlackScholesProcess(100, 0.02, 0.2, StateData(122.14027581601698, 1))
s1: 122.14027581601698
p2: BlackScholesProcess(100, 0.02, 0.2, StateData(122.14027581601698, 2))
s2: 122.14027581601698
reset p: BlackScholesProcess(100, 0.02, 0.2, StateData(100, 0))


It works fine:

- we can create an object and its attributes are initialized correctly
- when we update, the object attributes are modified correctly
- when we reset, the object attributes are reset to their initial values

However, the implementation is not entirely trivial. We fell into a few gotchas while
implementing it, all of which revolved around the handling of mutable data. Bugs are
common in code that mutates data when several elements should be mutated together,
because it is easy to miss one of them and leave the object in an inconsistent state.
We handled that above by encapsulating `updated_spot` and `updated_time` in a single
`StateData` class that we only set through its constructor, which ensures that both
attributes are always updated together.
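
One way to have Python enforce this discipline, rather than relying on convention, is a frozen dataclass. This is an alternative sketch and not the approach used in this section; the standalone `StateData` below is a hypothetical stand-in for the inner class:

```python
from dataclasses import dataclass, FrozenInstanceError


@dataclass(frozen=True)
class StateData:
    """Immutable (spot, time) pair: fields cannot be changed one at a time."""
    spot: float
    time: float


state = StateData(100.0, 0.0)
try:
    # updating only one of the two fields is rejected
    state.time = 1.0
except FrozenInstanceError:
    print("cannot modify a single field in place")

# the only way to move forward is to build a new, complete state
state = StateData(122.14, 1.0)
print(state)
```

Because every change goes through the constructor, the spot and the time cannot get out of sync.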

#### A superclass to encapsulate complex logic

Another technique we now introduce is using inheritance to have a superclass implement logic
that is common to the concept the class is meant to represent and let the subclasses implement
the specific part that is unique to the more specialized concept represented by the subclass.
This will allow us to achieve two objectives:

- maximize code reuse since the common part is implemented only once in the superclass `Process` even if 
we have many subclasses like `BlackScholesProcess`
- avoid some bugs by making the superclass maintain the type invariants

A class invariant (or type invariant) is a condition on the state stored in an object that must hold for every object of the class. 
Methods of the class should preserve the invariant.

Class invariants are established during construction and constantly maintained between calls to public methods.
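
As a small illustration, consider a hypothetical `SimulationClock` class (not part of this section's code) whose invariant is that time never moves backwards. The constructor establishes the invariant and the only public method refuses any change that would break it:

```python
class SimulationClock:
    """Hypothetical example: the invariant is that time never decreases."""

    def __init__(self):
        # invariant established at construction
        self.time = 0.0

    def advance(self, new_time):
        # refuse any change that would break the invariant
        if new_time < self.time:
            raise ValueError("time cannot move backwards")
        self.time = new_time


clock = SimulationClock()
clock.advance(1.0)
clock.advance(2.5)
try:
    clock.advance(2.0)
except ValueError as err:
    print("rejected:", err)

# the rejected call left the state untouched
print(clock.time)
```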

Let us implement the superclass `Process`.

In [9]:
class Process:
    """
    A class used to simulate a stochastic process

    Inner classes
    -------------
    StateData : encapsulates the mutable state data (spot, time)

    Attributes
    ----------
    spot : float
        The starting value of the stochastic process
    state : Process.StateData
        The current value of the mutable state (spot and time)
    
    Methods
    -------
    reset()
        Reset the state of the object back to the initial state to start a new path
    evolve(spot, dt, norms)
        Simulate the next value of the spot for a time increment dt
    update(time, norms)
        Simulate the process to time using the standard normal variable norms
    """

    class StateData:
        """Mutable state of the process"""

        def __init__(self, spot, time):
            """
            Parameters
            ----------
            spot : float
                The current value of the stochastic process
            time : float
                The current time in the simulation
            """
            self.spot = spot
            self.time = time
        
        def __repr__(self):
            return "{}({!r}, {!r})".format(
            self.__class__.__name__,
            self.spot, self.time
        )
   
    def __init__(self, state):  
        """
        Parameters
        ----------
        state : Process.StateData
            The initial state (spot and time) of the stochastic process
        """
        self.spot = state.spot
        # a simulation path starts at time 0 from value spot
        self.state = state
        
    def __repr__(self):
        return "{}({!r}, {!r})".format(
            self.__class__.__name__,
            self.spot, self.state
        )
    
    def reset(self):
        """Reset the state before starting a new path"""
        self.state = self.StateData(self.spot, 0)

    def evolve(self, spot, dt, norms):
        """
        Simulates the next value of the process for a time increment dt
        Should be implemented in subclasses to return the next spot value
        
        Parameters
        ----------
        spot : float
            The current value of the stochastic process
        dt : float
            The time increment
        norms : float
            The standard normal variable to simulate a Brownian increment
        """
        raise NotImplementedError(
            "Classes derived from Process should implement evolve() method")

    def update(self, time, norms):        
        """
        Updates the process state to time `time`
        
        Parameters
        ----------
        time : float
            The time to evolve the process to
        norms: float
            The standard normal variable to simulate a Brownian increment
        """        
        # calculate time increment
        dt = time - self.state.time
        # simulate spot incrementally
        st = self.evolve(self.state.spot, dt, norms)
        # update state
        self.state = self.StateData(st, time)

        return st

The `Process` class implements most of the behavior, in particular it handles all the state
mutations and only delegates to the subclasses to implement their specific formula to 
generate $S_{t_{i+1}}$ given $S_{t_i}$ in a single function without any mutations.

In a sense, the class `Process` is not complete as it leaves the `evolve` method for subclasses
to implement. If we instantiate an instance of the class `Process` and invoke its `update` or 
`evolve` method, we will get an error:

In [10]:
import pytest
x = Process(Process.StateData(100, 0))
print(x)
with pytest.raises(NotImplementedError) as error:
    x.update(1, 0)
print('got error:', error)

Process(100, StateData(100, 0))
got error: <ExceptionInfo NotImplementedError('Classes derived from Process should implement evolve() method') tblen=3>


We used the `pytest` module, a testing framework that we will learn to use in a future section. 
It will allow us to test our code more effectively and rigorously than what we have been doing so far.
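
A related option, shown here only as a sketch and not used in the rest of this section, is the standard library `abc` module. Marking `evolve` as abstract makes Python refuse to instantiate the incomplete class at all, instead of failing later when the method is called. The class names `AbstractProcess` and `ConstantProcess` are hypothetical:

```python
from abc import ABC, abstractmethod


class AbstractProcess(ABC):
    """Hypothetical variant of Process that declares evolve() as abstract."""

    @abstractmethod
    def evolve(self, spot, dt, norms):
        """Return the next spot value"""


class ConstantProcess(AbstractProcess):
    """Trivial subclass: the spot never moves."""

    def evolve(self, spot, dt, norms):
        return spot


try:
    # fails immediately because evolve() has no implementation
    AbstractProcess()
except TypeError as err:
    print("got error:", type(err).__name__)

# the subclass implements evolve(), so it can be instantiated
print(ConstantProcess().evolve(100, 1, 0.0))
```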

#### Subclassing

We now implement the `BlackScholesProcess` as a subclass that implements the `evolve` method:

In [11]:
class BlackScholesProcess(Process):
    """A class used to simulate a Black-Scholes Geometric Brownian Motion process"""

    def __init__(self, spot, rate, vol):  
        """
        Parameters
        ----------
        spot : float
            The starting value of the stochastic process
        rate : float
            The risk-free rate for the risk-neutral drift
        vol : float
            The annualized instantaneous volatility
        """
        super().__init__(Process.StateData(spot, 0))
        self.rate = rate
        self.vol = vol
        
    def __repr__(self):
        return "{}({!r}, {!r}, {!r}, {!r})".format(
            self.__class__.__name__,
            self.spot, self.rate, self.vol, self.state
        )
    
    def evolve(self, spot, dt, norms):                
        st = spot * np.exp((self.rate - 0.5 * self.vol**2) * dt + self.vol*np.sqrt(dt)*norms)        
        return st

We can test it as before:

In [12]:
test_black_scholes()

p0: BlackScholesProcess(100, 0.02, 0.2, StateData(100, 0))
p1: BlackScholesProcess(100, 0.02, 0.2, StateData(122.14027581601698, 1))
s1: 122.14027581601698
p2: BlackScholesProcess(100, 0.02, 0.2, StateData(122.14027581601698, 2))
s2: 122.14027581601698
reset p: BlackScholesProcess(100, 0.02, 0.2, StateData(100, 0))


The subclass `BlackScholesProcess` inherits from the superclass `Process` 
because of the statement `class BlackScholesProcess(Process)`. 

Inheriting from the class `Process` means the subclass `BlackScholesProcess` 
includes all the attributes and methods defined in `Process` unless it overrides
them with its own implementation.

In our example, `BlackScholesProcess` inherited the attributes `spot` and `state`, 
and the methods `reset` and `update`. 

It overrode the methods `evolve` and `__repr__`, and the constructor `__init__`. 

Note also how the subclass constructor calls the superclass constructor with the 
syntax `super().__init__(Process.StateData(spot, 0))`.

Furthermore, notice how the overridden `evolve` method in the subclass `BlackScholesProcess`
ends up being called by the inherited superclass method `Process.update()`. This
behavior is called __polymorphism__.

Finally, note how the subclass can add new attributes, for example the `rate` and `vol` attributes.
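
To see polymorphism at work with more than one subclass, here is a condensed sketch. `Process` is restated in a shortened form (the state is kept as a plain tuple for brevity) so that the block runs on its own, and `ArithmeticProcess` is a hypothetical second subclass introduced only for this illustration:

```python
import numpy as np


class Process:
    """Condensed restatement of the superclass: it owns the state and calls evolve()."""

    def __init__(self, spot):
        self.spot = spot
        self.state = (spot, 0)  # (current spot, current time)

    def evolve(self, spot, dt, norms):
        raise NotImplementedError("subclasses should implement evolve()")

    def update(self, time, norms):
        spot, last_time = self.state
        # this call dispatches to the evolve() of the object's actual class
        st = self.evolve(spot, time - last_time, norms)
        self.state = (st, time)
        return st


class BlackScholesProcess(Process):
    def __init__(self, spot, rate, vol):
        super().__init__(spot)
        self.rate = rate
        self.vol = vol

    def evolve(self, spot, dt, norms):
        return spot * np.exp((self.rate - 0.5 * self.vol**2) * dt
                             + self.vol * np.sqrt(dt) * norms)


class ArithmeticProcess(Process):
    """Hypothetical second subclass: arithmetic Brownian motion with drift."""

    def __init__(self, spot, drift, vol):
        super().__init__(spot)
        self.drift = drift
        self.vol = vol

    def evolve(self, spot, dt, norms):
        return spot + self.drift * dt + self.vol * np.sqrt(dt) * norms


for process in [BlackScholesProcess(100, 0.02, 0.2), ArithmeticProcess(100, 2.0, 20.0)]:
    # the same inherited update() runs a different evolve() for each object
    print(type(process).__name__, process.update(1, 1.0))
```

The loop never checks which kind of process it holds: the inherited `update()` is identical for both objects, and only the formula inside `evolve()` differs.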

### OptionState class hierarchy

Similarly, we now implement a superclass `OptionState`:

In [13]:
class OptionState: 

    class StateData:

        def __init__(self, updates, underlying, state):
            self.updates = updates
            self.underlying = underlying
            self.state = state

        def __repr__(self):
            return "{}({!r}, {!r}, {!r})".format(
                self.__class__.__name__,
                self.updates, self.underlying, self.state)        


    def __init__(self, spot, expiry, payoff, times, state_data=None):
        self.spot = spot
        self.expiry = expiry
        self.payoff = payoff
        self.times = times
        self.state_data = state_data if state_data else self.StateData([], [], None)

    def reset(self):
        self.state_data = self.StateData([], [], None)    
    
    def __repr__(self):
        return ("{}({!r}, {!r}, {!r}, {!r}, {!r})".format(
            self.__class__.__name__,
            self.spot, self.expiry, self.payoff, self.times, self.state_data))

    def has_expired(self):
        if self.state_data.updates:
            return self.expiry <= max(self.state_data.updates)
        else:
            return False
    
    def update(self, time, underlying):
        if time in self.times and  time not in self.state_data.updates:
            self.state_data.updates.append(time)
            self.state_data.underlying.append(underlying)        
            self.state_data.state = self.update_state()        
    
    def update_state(self):
        if self.has_expired():
            return self.calculate_state_variable()
        else:
            return None
        
    def calculate_state_variable(self):
        raise NotImplementedError(
            "Classes derived from OptionState should implement calculate_state_variable() method")
    

    def calculate_payoff(self):
        if self.has_expired():
            return self.payoff(self.state_data.state)
        else:
            raise RuntimeError(
                "calculate_payoff() should only be called once option is expired")
    

Let's try it:

In [14]:
x = OptionState(100, 5, lambda x: 2 * x, [1, 2, 3, 4, 5])
print(x)

assert not x.has_expired()

x.update(1.5, 200)
print(x)

with pytest.raises(RuntimeError) as error:
    x.calculate_payoff()
print('got error:', error)

with pytest.raises(NotImplementedError) as error:
    x.calculate_state_variable()
print('got error:', error)

print(x.update_state())
print(x.update(1, 101.))

OptionState(100, 5, <function <lambda> at 0x7fad54453a60>, [1, 2, 3, 4, 5], StateData([], [], None))
OptionState(100, 5, <function <lambda> at 0x7fad54453a60>, [1, 2, 3, 4, 5], StateData([], [], None))
got error: <ExceptionInfo RuntimeError('calculate_payoff() should only be called once option is expired') tblen=2>
got error: <ExceptionInfo NotImplementedError('Classes derived from OptionState should implement calculate_state_variable() method') tblen=2>
None
None


All the complexity is in the superclass `OptionState`, as it defines 
all the behavior expected from the more specific option state classes that 
will inherit from it. The subclasses only have to implement their
constructor and the `calculate_state_variable` method.

For example, to implement an Asian option, we define a subclass `AsianState`
inheriting from `OptionState`:

In [15]:
import statistics

class AsianState(OptionState): # AsianState inherits from OptionState

    def __init__(self, *args, **kargs):
        super().__init__(*args, **kargs)

    def calculate_state_variable(self):
        return statistics.mean(self.state_data.underlying)         

We now define a function to do some basic tests on instances of subclasses
of `OptionState`:

In [16]:
def test_option_state(x):
    print(x)

    assert not x.has_expired()

    x.update(1.5, 200)
    print(x)

    with pytest.raises(RuntimeError) as error:
        x.calculate_payoff()
    print('got error:', error)

    x.update(1, 101.)
    assert not x.has_expired()
    x.update(2, 102.)
    assert not x.has_expired()
    x.update(3, 103.)
    assert not x.has_expired()
    x.update(4, 104.)
    assert not x.has_expired()
    x.update(5, 105.)
    assert x.has_expired()

    print(x.calculate_payoff())

In [17]:
asian_state = AsianState(100, 5, lambda x: 2 * x, [1, 2, 3, 4, 5])
test_option_state(asian_state)

AsianState(100, 5, <function <lambda> at 0x7fad544078b0>, [1, 2, 3, 4, 5], StateData([], [], None))
AsianState(100, 5, <function <lambda> at 0x7fad544078b0>, [1, 2, 3, 4, 5], StateData([], [], None))
got error: <ExceptionInfo RuntimeError('calculate_payoff() should only be called once option is expired') tblen=2>
206.0
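
The value `206.0` can be checked by hand. The update at time `1.5` is ignored because `1.5` is not one of the observation times, so the state variable is the average of the five recorded values, and the payoff used in the test doubles it:

```python
import statistics

# values recorded at times 1, 2, 3, 4 and 5 in test_option_state
observations = [101., 102., 103., 104., 105.]

# state variable of the Asian option: arithmetic average of the observations
average = statistics.mean(observations)

# payoff function passed to AsianState in the test above
payoff = (lambda x: 2 * x)(average)

print(average, payoff)
```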


We can implement other path-dependent options as easily, for example:

In [18]:
class MaxState(OptionState):

    def __init__(self, *args, **kargs): 
        super().__init__(*args, **kargs)               

    def calculate_state_variable(self):
        return max(self.state_data.underlying)        

In [19]:
max_state = MaxState(100, 5, lambda x: x, [1, 2, 3, 4, 5])
test_option_state(max_state)

MaxState(100, 5, <function <lambda> at 0x7fad54407ca0>, [1, 2, 3, 4, 5], StateData([], [], None))
MaxState(100, 5, <function <lambda> at 0x7fad54407ca0>, [1, 2, 3, 4, 5], StateData([], [], None))
got error: <ExceptionInfo RuntimeError('calculate_payoff() should only be called once option is expired') tblen=2>
105.0


### Putting it together

Let's now use our classes to price an Asian option using the path-dependent pricing function:

In [20]:
T = 5.
S_0 = 100.
vol = 0.2
r = 0.02
K = 100
N = 10000

times = [1, 2, 3, 4, 5]
yield_curve = FixedRateYieldCurve(r)
rand_process = BlackScholesProcess(S_0, r, vol)
call_payoff = Call(K)
x = AsianState(S_0, T, call_payoff, times)

p = mc_path_dependent(x, rand_process, yield_curve, N)
p

14.1866392666187

We can check that our function also works for the `MaxState` subclass:

In [21]:
rand_process = BlackScholesProcess(S_0, r, vol)
y = MaxState(S_0, T, call_payoff, times)
l = mc_path_dependent(y, rand_process, yield_curve, N)
l

31.150778705556572

### Performance comparison

Let's compare the performance vs a simple vectorized implementation for Asian options:

In [22]:
from scipy import stats
import numpy as np
import math

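def mc_asian_vectorized(dt, times, spot, strike, rate, vol, maturity, n_paths):    
    # one standard normal draw per path and per time node
    e = np.random.normal(size=(n_paths, times.size))
    # Brownian increments on a uniform grid with step dt
    dW_t = math.sqrt(dt) * e
    # a node at time 0 needs no increment, so that W_0 = 0 and S_0 = spot
    dW_t[:, times == 0] = 0.
    W_t = np.cumsum(dW_t, axis=1)

    # spot at every time node of every path
    St = spot * np.exp((rate-0.5*vol**2)*times + vol * W_t)
    # arithmetic average along each path, then the call payoff on the average
    A = np.mean(St, axis=1)
    Ct = np.maximum(A - strike, 0) 
    return math.exp(-rate*maturity) * np.mean(Ct)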

In [23]:
import time

T = 5.
S_0 = 100.
vol = 0.2
r = 0.02
K = 100
N = 10000

dt = 1/12

times = np.arange(0, T+dt, dt)

yield_curve = FixedRateYieldCurve(r)
rand_process = BlackScholesProcess(S_0, r, vol)
call_payoff = Call(K)
x = AsianState(S_0, T, call_payoff, times)

start = time.time()
p = mc_path_dependent(x, rand_process, yield_curve, N)
end = time.time()
print(p, end-start)

12.029982932940722 7.2420477867126465


In [24]:
start = time.time()
p = mc_asian_vectorized(dt, times, S_0, K, r, vol, T, N)
end = time.time()
print(p, end-start)

12.392244064721993 0.030224323272705078
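
The gap comes mostly from the Python-level loop rather than from the use of classes. The sketch below isolates that effect on a simpler problem chosen for illustration, a European call priced from terminal values only rather than the Asian option above. The seed, the number of draws, and the helper `terminal_spot` are arbitrary assumptions; timings depend on the machine, so none are quoted:

```python
import time
import numpy as np

rng = np.random.default_rng(0)  # seeded generator, an arbitrary choice
draws = rng.standard_normal(100_000)
spot, strike, rate, vol, maturity = 100., 100., 0.02, 0.2, 5.


def terminal_spot(z):
    # works for a single draw as well as for an array of draws
    return spot * np.exp((rate - 0.5 * vol**2) * maturity + vol * np.sqrt(maturity) * z)


# one draw at a time, as in the path-by-path pricer
start = time.perf_counter()
total = 0.
for z in draws:
    total += max(terminal_spot(z) - strike, 0.)
loop_mean = total / draws.size
loop_seconds = time.perf_counter() - start

# all draws at once with NumPy
start = time.perf_counter()
vector_mean = np.mean(np.maximum(terminal_spot(draws) - strike, 0.))
vector_seconds = time.perf_counter() - start

print(loop_mean, loop_seconds)
print(vector_mean, vector_seconds)
```

Both versions use the same random draws, so the two averages agree up to floating-point rounding and only the elapsed times differ.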


## Conclusion

Although we managed to design convenient and reusable components through
classes and inheritance, our path-dependent pricing function ended up 
significantly slower than the vectorized implementation (about 7.2 seconds versus
0.03 seconds in the runs above) because we did not make use of vectorization, which is
critical for performance when doing numerical work with Python. 
In a future topic, we will revisit this example to speed up the implementation.