<a href="https://colab.research.google.com/github/yuminlinsche/DependentDirichletProcessTF/blob/main/american_vanilla_binary_tree_new.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import sys
import functools
import inspect
import re
import time
import contextlib
import math
from tqdm import tqdm
from tqdm import trange
import numpy as np
import pandas as pd

import tensorflow as tf
import tensorflow_probability as tfp
tfd=tfp.distributions
tfb=tfp.bijectors
tfm=tf.math
tfl=tf.linalg
tfdtype=tf.float64

EPS=sys.float_info.epsilon

import plotly.express as px
import plotly.graph_objects as go
import plotly.figure_factory as ff

@contextlib.contextmanager
def time_here(msg=None, log_func=print):
    '''
    Context manager that logs how long its ``with`` body took.

    Args:
        msg: label at the start of the log line; 'time_here' when falsy.
        log_func: callable that receives the formatted message (default: print).

    The duration is measured with time.perf_counter_ns (monotonic and
    high-resolution, so system clock adjustments cannot distort it) and is
    logged even when the body raises; the exception then propagates unchanged.
    '''
    begin_time = time.perf_counter_ns()
    try:
        yield
    finally:
        time_elapsed = time.perf_counter_ns() - begin_time
        log_func(f"{msg or 'time_here'} | {time_elapsed/10**9} sec == {time_elapsed/10**6} millisec == {time_elapsed/10**3} microsec")

def timer_stat(func,iter_num=1000):
    '''
    Wrap `func` so that every call runs it `iter_num` times and reports timings.

    The returned wrapper keeps `func`'s metadata (via functools.wraps), prints
    the average, standard deviation, maximum and minimum run time in seconds,
    and returns ``(result, summary)``: `result` is the return value of the
    last repetition and `summary` maps 'Avg', 'Std', 'Max', 'Min' to seconds.

    Args:
        func: callable to benchmark.
        iter_num: number of repetitions per wrapper call; must be >= 1.

    Raises:
        ValueError: if iter_num < 1 (there would be no result to return and
            nothing to summarise).
    '''
    if iter_num < 1:
        raise ValueError(f"iter_num must be >= 1, got {iter_num}")

    @functools.wraps(func)
    def time_wrapper(*args,**kwargs):
        durations_ns=[]
        for _ in range(iter_num):
            # perf_counter_ns is the monotonic clock intended for benchmarking
            start_time=time.perf_counter_ns()
            result=func(*args,**kwargs)
            durations_ns.append(time.perf_counter_ns()-start_time)
        durations_s=np.array(durations_ns)/10**9

        summary={
            'Avg':np.mean(durations_s),
            'Std':np.std(durations_s),
            'Max':np.max(durations_s),
            'Min':np.min(durations_s),
            }

        for key,value in summary.items():
            print(func.__name__+' '+key+' run time:'+str(value))

        return result,summary
    return time_wrapper

In [3]:
@tf.function()
def tensordot_apply_mask_to_data(data, mask, axis=0):
    '''
    Contract `mask` with `data` along `axis` while keeping the layout of `data`.

    The last dimension of `mask` is contracted against dimension `axis` of
    `data` with tf.tensordot. The remaining (leading) dimensions of `mask` are
    then moved back to position `axis`, so every other dimension of `data`
    keeps its original position. With a rank-2 mask of shape
    [n_out, data.shape[axis]] this applies a linear map (e.g. differences or
    a one-step expectation) along `axis`: the output has the same rank as
    `data`, with dimension `axis` resized to n_out. A rank-1 mask reduces
    `axis` away.

    Args:
        data: tensor to transform.
        mask: tensor of rank >= 1 whose rank is statically known; its last
            dimension must match data.shape[axis].
        axis: dimension of `data` to contract. A negative python int counts
            from the end and requires `data` to have a statically known rank.

    Returns:
        Tensor of rank rank(data) + rank(mask) - 2.

    Raises:
        ValueError: if the mask rank is unknown or 0, or a negative axis is
            given for data of unknown rank.
    '''
    data = tf.convert_to_tensor(data)
    mask = tf.convert_to_tensor(mask)
    mask_rank = mask.shape.rank
    if mask_rank is None or mask_rank < 1:
        raise ValueError('mask must have a statically known rank >= 1')
    if not tf.is_tensor(axis) and axis < 0:
        data_static_rank = data.shape.rank
        if data_static_rank is None:
            raise ValueError('a negative axis requires data with a statically known rank')
        axis += data_static_rank

    # number of mask dimensions that survive the contraction
    lead = mask_rank - 1
    applied_data = tf.tensordot(a=mask, b=data, axes=[lead, axis])

    # applied_data dims are [mask dims (lead) | data dims before axis | data dims after axis];
    # this permutation puts the mask dims back where `axis` was.
    data_rank = tf.rank(data)
    perm = tf.concat([
        tf.range(lead, lead + axis),
        tf.range(0, lead),
        tf.range(lead + axis, lead + data_rank - 1),
    ], axis=0)
    return tf.transpose(applied_data, perm=perm)

In [5]:
@tf.function()
def get_binarytree_american(initial_value,time_to_expiry,strike_price,
                            pricing_volatility,carry_cost,convinence_yield=0.,
                            option_type='PUT',
                            dividend_order=6,
                            is_output_dict=False,
                            dtype=tf.float64):
    '''
    Price an American option on a Cox-Ross-Rubinstein binomial tree.

    The tree has 2**dividend_order time steps. In every [time_num, time_num]
    lattice below, row t is the time step and column j the node reached with
    t-j up moves and j down moves; entries above the diagonal (j > t) are
    unused and kept at 0.

    Args:
        initial_value: spot price of the underlying at t=0.
        time_to_expiry: time to expiry, in the same unit as the rates and the
            volatility (presumably years - TODO confirm).
        strike_price: strike K.
        pricing_volatility: volatility sigma; up/down factors are
            exp(+/- sigma*sqrt(dt)).
        carry_cost: continuously compounded rate used for the drift of the
            underlying and for discounting (the notebook passes the interest
            rate here).
        convinence_yield: continuous yield subtracted from the drift only
            (e.g. a dividend yield).
        option_type: 'CALL' or 'PUT' (case-insensitive).
        dividend_order: log2 of the number of time steps.
        is_output_dict: if True, return a dict with the full lattices.
        dtype: floating dtype used for every tensor.

    Returns:
        The present value (scalar tensor) or, when is_output_dict is True, a
        dict that also holds 'options_value' (American value lattice),
        'options_continuing_value', 'options_exercising_value' and the inputs.

    Raises:
        ValueError: if option_type is not 'CALL' or 'PUT'.
    '''
    s0=tf.convert_to_tensor(initial_value,dtype=dtype)
    tau=tf.convert_to_tensor(time_to_expiry,dtype=dtype)
    K=tf.convert_to_tensor(strike_price,dtype=dtype)
    volatility=tf.convert_to_tensor(pricing_volatility,dtype=dtype)
    # convert_to_tensor (unlike tf.constant) also accepts tensor-valued rates
    carry_cost_rate=tf.convert_to_tensor(carry_cost,dtype=dtype)
    convinence_yield_rate=tf.convert_to_tensor(convinence_yield,dtype=dtype)

    option_type_2_sign_table={
            'CALL':+1.0,
            'PUT':-1.0
        }
    option_type_sign=option_type_2_sign_table.get(str(option_type).upper())
    if option_type_sign is None:
        raise ValueError(f"option_type must be 'CALL' or 'PUT', got {option_type!r}")

    step_num=2**dividend_order
    time_num=step_num+1
    dt=tau/step_num
    sqrt_dt=tfm.sqrt(dt)

    def lower_triangle(x):
        # keep the diagonal and everything below it, zero the rest
        return tf.linalg.band_part(x,num_lower=-1,num_upper=0)

    def payoff_func(binary_tree_S_t,strike):
        # intrinsic value max(sign*(S-K), 0) restricted to reachable nodes
        strike_t=tf.convert_to_tensor(strike,dtype=binary_tree_S_t.dtype)[...,tf.newaxis,tf.newaxis]
        return lower_triangle(tf.math.maximum(option_type_sign*binary_tree_S_t-option_type_sign*strike_t,0.))

    base_mask=lower_triangle(tf.ones([time_num,time_num],dtype=dtype))
    base_mask_1d=tf.linalg.diag_part(base_mask)
    volatility_mask=base_mask*volatility
    market_up_return_mask=tfm.exp(volatility_mask*sqrt_dt)
    market_down_return_mask=tfm.exp(-volatility_mask*sqrt_dt)

    carry_cost_rate_mask=base_mask_1d*carry_cost_rate
    convinence_yield_rate_mask=base_mask_1d*convinence_yield_rate
    # one-step risk-neutral growth of the underlying
    opportunity_cost_mask=tfm.exp((carry_cost_rate_mask-convinence_yield_rate_mask)*dt)
    # discount with the carry-cost rate only; the yield affects the drift, not discounting
    discount_factor_mask=tfm.exp(-carry_cost_rate_mask*dt)
    # p = (growth - d)/(u - d); above the diagonal u == d (division by zero),
    # and those entries are zeroed by lower_triangle
    market_up_prob_mask=lower_triangle((opportunity_cost_mask-market_down_return_mask)/(market_up_return_mask-market_down_return_mask))
    market_down_prob_mask=lower_triangle(1.-market_up_prob_mask)

    # node (t, j): t-j up moves and j down moves
    market_up_mask=lower_triangle(tfm.cumsum(base_mask,axis=0)-1)
    market_down_mask=lower_triangle(tfm.cumsum(base_mask,axis=1)-1)
    market_return_mask=lower_triangle(tfm.exp((market_up_mask-market_down_mask)*volatility_mask*sqrt_dt))
    market_price_mask=lower_triangle(s0*market_return_mask)

    # value of exercising immediately at every node
    option_value_exercising=payoff_func(market_price_mask,K)

    # cal_mask[t] maps the value row at step t+1 to the expected value row at
    # step t: row i holds the up probability in column i and the down
    # probability in column i+1.
    cal_mask=tf.eye(time_num-1,dtype=dtype)
    cal_mask=tf.broadcast_to(cal_mask,[time_num-1,time_num-1,time_num-1])
    cal_mask=tf.pad(tf.linalg.set_diag(cal_mask,market_down_prob_mask[:-1,:-1]),[[0,0],[0,1],[1,0]])
    cal_mask=tf.linalg.set_diag(cal_mask,market_up_prob_mask[:-1])

    # Backward induction. At maturity the option is worth its payoff. At each
    # earlier step the continuation value is the discounted expectation of the
    # American value one step later, and the American value is the larger of
    # continuing and exercising now (early-exercise check at every step).
    value_rows=[option_value_exercising[-1]]
    continuing_rows=[option_value_exercising[-1]]
    for t in range(time_num-2,-1,-1):
        continuing=tensordot_apply_mask_to_data(value_rows[-1],cal_mask[t],0)*discount_factor_mask[t+1]
        continuing_rows.append(continuing)
        value_rows.append(tf.math.maximum(option_value_exercising[t],continuing))
    # rows were collected from maturity backwards; stack them in time order
    option_value_continuing=tf.stack(continuing_rows[::-1])
    option_value=tf.stack(value_rows[::-1])

    option_present_value=option_value[0,0]
    if not is_output_dict:
        return option_present_value
    return {'options_present_value':option_present_value,
            'options_value':option_value,
            'options_continuing_value':option_value_continuing,
            'options_exercising_value':option_value_exercising,
            'options_type':option_type,
            'underlying_price':s0,
            'time_to_expiry':tau,
            'strike_price':K,
            'pricing_voltality':volatility,
            'carry_cost':carry_cost_rate,
            'convinence_yield':convinence_yield_rate}

In [21]:
# Example: American put on a 2**6-step tree, timed with time_here.
s0, T, K = 112.54, 1., 100.
vol, ir, dr = 0.123, 0.045, 0.00
with time_here():
    binary_tree_result = get_binarytree_american(
        initial_value=s0,
        strike_price=K,
        time_to_expiry=T,
        pricing_volatility=vol,
        carry_cost=ir,
        convinence_yield=dr,
        option_type='PUT',
        dividend_order=6,
        is_output_dict=True,
    )

time_here | 0.004171762 sec == 4.171762 millisec == 4171.762 microsec


In [22]:
# Heat map of the 'options_value' lattice returned above. After the transpose,
# columns are time steps and rows are tree nodes (index j = number of down moves).
px.imshow(binary_tree_result['options_value'].numpy().T,color_continuous_scale='jet')