In [1]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt

In [2]:
class ValidationError(Exception):
    """Raised by a product validator when field values violate a constraint."""


def check_order(*args):
    """Build a validator requiring the named fields to be non-decreasing.

    ``check_order('a', 'b', 'c')`` returns a callable that accepts a
    ``{name: value}`` dict and checks ``fields['a'] <= fields['b'] <= fields['c']``
    pair by pair, raising ValidationError at the first adjacent pair that is
    out of order. Equal values are accepted.
    """
    def validate(fields):
        # zip stops at the shorter sequence, so this walks adjacent pairs.
        for smaller, greater in zip(args, args[1:]):
            if fields[greater] < fields[smaller]:
                raise ValidationError("%s must not be smaller than %s" % (greater, smaller))
    return validate


def range_validator(name, minimum, maximum):
    """Build a validator requiring ``fields[name]`` to lie in [minimum, maximum].

    Both bounds are inclusive; a value outside them raises ValidationError.
    """
    def validate(fields):
        candidate = fields[name]
        if minimum <= candidate <= maximum:
            return
        raise ValidationError("%s must be between %s and %s" % (name, minimum, maximum))
    return validate


class Field(object):
    """Declaration of one product parameter.

    Attributes:
        name: keyword used to pass the value to a Product constructor.
        default: value used when the caller does not supply one.
        help: free-text description; normalised to '' when falsy.
        hide: when True, plot_payoff leaves the field out of the plotted
            spot range and does not draw a marker line for it.
    """

    def __init__(self, name, default=0, help=None, hide=False):
        self.name, self.default = name, default
        self.help = '' if not help else help
        self.hide = hide


class Product(object):
    """Base class for products defined by named numeric fields and a payoff.

    Subclasses declare their parameters by overriding the ``fields``
    classmethod (a list of Field objects), add constraints by extending
    ``validators``, and implement ``payoff``.

    NOTE: after ``__init__``, the instance attribute ``self.fields`` is a
    ``{name: value}`` dict that shadows the ``fields`` classmethod on the
    instance. Use ``cls.fields()`` / ``type(self).fields()`` or
    ``field_dict()`` to reach the Field declarations.
    """

    @classmethod
    def fields(cls):
        """Return this product's Field declarations (none for the base class)."""
        return []

    @classmethod
    def field_dict(cls):
        """Return the Field declarations keyed by field name."""
        return {f.name: f for f in cls.fields()}

    def __init__(self, **kwargs):
        """Create a product from keyword field values.

        Fields not passed take their declared defaults.

        Raises:
            ValueError: a keyword does not name a declared field.
            ValidationError: the merged values (defaults overridden by
                ``kwargs``) fail one of ``validators()``.
        """
        declared = self.field_dict()
        for name in kwargs:
            if name not in declared:
                raise ValueError("Field '%s' is not recognized" % name)
        values = {name: f.default for name, f in declared.items()}
        values.update(kwargs)
        # Validate the complete set of values. Validators such as check_order
        # index fields directly, so validating only kwargs would raise
        # KeyError for any field the caller left at its default.
        self.validate(values)
        self.fields = values

    @classmethod
    def validators(cls):
        """Return the validator callables; the base set requires every value > 0."""
        def check_positive(fields):
            for name, value in fields.items():
                if value <= 0:
                    raise ValidationError("%s must be positive" % name)
        return [check_positive]

    @classmethod
    def validate(cls, fields):
        """Run every validator on the ``{name: value}`` dict ``fields``."""
        for validator in cls.validators():
            validator(fields)

    def barriers(self):
        """Return the barrier levels of this product (none by default)."""
        return []

    def payoff(self, spot):
        """Return the payoff at ``spot``; subclasses must override."""
        raise NotImplementedError


class Forward(Product):
    """Outright forward at ``price``: pays ``spot - price``."""

    @classmethod
    def fields(cls):
        return [Field(name='price', default=100)]

    def payoff(self, spot):
        agreed_price = self.fields['price']
        return spot - agreed_price


class VanillaOption(Product):
    """Base for single-strike options: declares one ``strike`` field (default 100)."""

    @classmethod
    def fields(cls):
        strike_field = Field(name='strike', default=100)
        return [strike_field]


class Call(VanillaOption):
    """Long call: pays ``max(0, spot - strike)`` element-wise via np.maximum."""

    def payoff(self, spot):
        intrinsic = spot - self.fields['strike']
        return np.maximum(0.0, intrinsic)


class Put(VanillaOption):
    """Long put: pays ``max(0, strike - spot)`` element-wise via np.maximum."""

    def payoff(self, spot):
        intrinsic = self.fields['strike'] - spot
        return np.maximum(0.0, intrinsic)


class LeveragedProduct(Product):
    """Base for products carrying a hidden ``leverage_ratio`` field (default 1).

    The ratio must lie in [0, 2] per range_validator. The inherited base
    positivity check additionally rejects 0, so the accepted range is (0, 2].
    """

    @classmethod
    def fields(cls):
        ratio_field = Field('leverage_ratio', default=1, hide=True)
        return [ratio_field]

    @classmethod
    def validators(cls):
        inherited = super().validators()
        return inherited + [range_validator('leverage_ratio', 0, 2)]

    def leverage(self):
        """Return this instance's ``leverage_ratio`` value."""
        ratio = self.fields['leverage_ratio']
        return ratio


class CallSpread(LeveragedProduct):
    """Long call at ``long_call`` minus ``leverage_ratio`` calls at ``short_call``.

    Requires ``long_call <= short_call``.
    """

    @classmethod
    def fields(cls):
        own = [Field('long_call', default=100), Field('short_call', default=120)]
        return own + super().fields()

    @classmethod
    def validators(cls):
        inherited = super().validators()
        return inherited + [check_order('long_call', 'short_call')]

    def payoff(self, spot):
        bought = Call(strike=self.fields['long_call']).payoff(spot)
        ratio = self.leverage()
        sold = Call(strike=self.fields['short_call']).payoff(spot)
        return bought - ratio * sold


class PutSpread(LeveragedProduct):
    """Long put at ``long_put`` minus ``leverage_ratio`` puts at ``short_put``.

    Requires ``short_put <= long_put``.
    """

    @classmethod
    def fields(cls):
        own = [Field('long_put', default=120), Field('short_put', default=100)]
        return own + super().fields()

    @classmethod
    def validators(cls):
        inherited = super().validators()
        return inherited + [check_order('short_put', 'long_put')]

    def payoff(self, spot):
        bought = Put(strike=self.fields['long_put']).payoff(spot)
        ratio = self.leverage()
        sold = Put(strike=self.fields['short_put']).payoff(spot)
        return bought - ratio * sold


class Collar(Product):
    """Long call at ``call_strike`` plus short put at ``put_strike``.

    Requires ``put_strike <= call_strike``.
    """

    @classmethod
    def fields(cls):
        return [
            Field('call_strike', default=120),
            Field('put_strike', default=100),
        ]

    @classmethod
    def validators(cls):
        inherited = super().validators()
        return inherited + [check_order('put_strike', 'call_strike')]

    def payoff(self, spot):
        call_leg = Call(strike=self.fields['call_strike']).payoff(spot)
        put_leg = Put(strike=self.fields['put_strike']).payoff(spot)
        return call_leg - put_leg


class Seagull(Product):
    """Long call at ``long_call``, short put at ``short_put``, short call at ``upper_strike``.

    Requires ``short_put <= long_call <= upper_strike``.
    """

    @classmethod
    def fields(cls):
        return [
            Field('short_put', default=100),
            Field('long_call', default=110),
            Field('upper_strike', default=120),
        ]

    @classmethod
    def validators(cls):
        inherited = super().validators()
        return inherited + [check_order('short_put', 'long_call', 'upper_strike')]

    def payoff(self, spot):
        values = self.fields
        bought_call = Call(strike=values['long_call']).payoff(spot)
        sold_put = Put(strike=values['short_put']).payoff(spot)
        sold_call = Call(strike=values['upper_strike']).payoff(spot)
        return bought_call - sold_put - sold_call


class EnhancedForward(Product):
    """Forward at ``strike`` minus a call at ``upper_strike``.

    The short call caps the payoff at ``upper_strike - strike`` once spot is
    above ``upper_strike``. Requires ``strike <= upper_strike``.
    """

    @classmethod
    def fields(cls):
        return [
            Field('strike', default=100),
            Field('upper_strike', default=120),
        ]

    @classmethod
    def validators(cls):
        inherited = super().validators()
        return inherited + [check_order('strike', 'upper_strike')]

    def payoff(self, spot):
        forward_leg = Forward(price=self.fields['strike']).payoff(spot)
        cap_leg = Call(strike=self.fields['upper_strike']).payoff(spot)
        return forward_leg - cap_leg


class BarrierOption(Product):
    """Base for products exposing barrier levels through ``barriers()``.

    Fields whose names appear in ``barrier_names()`` are reported as
    barriers. plot_payoff draws the payoff jump at each one as a dotted line.
    """

    @classmethod
    def barrier_names(cls):
        return {'barrier'}

    def barriers(self):
        """Return the values of barrier-named fields in field declaration order."""
        wanted = self.barrier_names()
        return [self.fields[key] for key in self.fields if key in wanted]


class KOForward(LeveragedProduct, BarrierOption):
    """Long call minus ``leverage_ratio`` puts, both at ``strike``, knocked out at ``barrier``.

    The payoff is zero wherever ``spot >= barrier``. Requires
    ``strike <= barrier``.
    """

    @classmethod
    def fields(cls):
        own = [Field('strike', default=100), Field('barrier', default=120)]
        return own + super().fields()

    @classmethod
    def validators(cls):
        inherited = super().validators()
        return inherited + [check_order('strike', 'barrier')]

    def payoff(self, spot):
        strike = self.fields['strike']
        call_leg = Call(strike=strike).payoff(spot)
        ratio = self.leverage()
        put_leg = Put(strike=strike).payoff(spot)
        alive = spot < self.fields['barrier']
        return (call_leg - ratio * put_leg) * alive


class ForwardExtra(BarrierOption):
    """Forward at ``strike`` that pays zero while ``barrier <= spot <= strike``.

    Requires ``barrier <= strike``.
    """

    @classmethod
    def fields(cls):
        return [
            Field('strike', default=120),
            Field('barrier', default=100),
        ]

    @classmethod
    def validators(cls):
        inherited = super().validators()
        return inherited + [check_order('barrier', 'strike')]

    def payoff(self, spot):
        strike = self.fields['strike']
        forward_leg = Forward(price=strike).payoff(spot)
        active = np.logical_or(spot < self.fields['barrier'], strike < spot)
        return forward_leg * active


class CollarExtra(BarrierOption):
    """Collar (call at ``call_strike``, short put at ``put_strike``) that pays zero
    while ``barrier <= spot <= put_strike``.

    Requires ``barrier <= put_strike <= call_strike``.
    """

    @classmethod
    def fields(cls):
        return [
            Field('call_strike', default=120),
            Field('put_strike', default=110),
            Field('barrier', default=100),
        ]

    @classmethod
    def validators(cls):
        inherited = super().validators()
        return inherited + [check_order('barrier', 'put_strike', 'call_strike')]

    def payoff(self, spot):
        put_strike = self.fields['put_strike']
        collar_leg = Collar(call_strike=self.fields['call_strike'], put_strike=put_strike).payoff(spot)
        active = np.logical_or(spot < self.fields['barrier'], put_strike < spot)
        return collar_leg * active

In [3]:
from sys import stderr

import ipywidgets as widgets
from IPython.display import display

# Products offered in the selector, keyed by class name (insertion order is
# the display order).
products = {p.__name__: p for p in (Forward, Collar, Call, Put, CallSpread, PutSpread, Seagull,
                                    EnhancedForward, KOForward, ForwardExtra, CollarExtra)}

# matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and the
# old names were later removed, so plt.style.use('seaborn-whitegrid') fails on
# current releases. Prefer the new name and fall back for older matplotlib.
_GRID_STYLE = 'seaborn-v0_8-whitegrid'
plt.style.use(_GRID_STYLE if _GRID_STYLE in plt.style.available else 'seaborn-whitegrid')

# Number of spot samples used to draw each payoff curve.
PLOT_SAMPLES = 1000


def plot_main(ax, x, y, is_dotted=False):
    """Draw one payoff segment on ``ax`` in the first colour of the style's cycle.

    Segments are drawn 3 points wide and above other artists (zorder 1000);
    ``is_dotted`` selects a ':' line instead of a solid one.
    """
    palette = plt.rcParams['axes.prop_cycle'].by_key()['color']
    style = ':' if is_dotted else '-'
    ax.plot(x, y, lw=3, linestyle=style, color=palette[0], zorder=1000, label="Payoff")

def plot_payoff(product, size, **kwargs):
    """Plot the payoff of ``product(**kwargs)`` against spot at maturity.

    Args:
        product: a Product subclass (not an instance); it is instantiated with
            ``kwargs``. A ValidationError is written to stderr and nothing is
            plotted.
        size: margin added below the smallest and above the largest visible
            (non-hidden) field value to obtain the plotted spot range.
        **kwargs: field values passed to the product constructor.

    Each visible field is marked by a labelled dashed vertical line. The payoff
    curve is drawn piecewise, with a dotted vertical segment bridging the jump
    at every barrier reported by ``barriers()``.
    """
    try:
        instance = product(**kwargs)
    except ValidationError as ex:
        stderr.write('%s\n' % ex)
        return

    field_dict = instance.field_dict()
    visible = [(key, value) for key, value in kwargs.items() if not field_dict[key].hide]
    # Fall back to every value if all fields happen to be hidden.
    range_values = [value for _, value in visible] or list(kwargs.values())
    spots = np.linspace(min(range_values) - size, max(range_values) + size, PLOT_SAMPLES)

    payoff_values = instance.payoff(spots)
    label_top = np.max(payoff_values)  # hoisted: used for every field label

    plt.figure(figsize=(12, 12))
    ax = plt.axes()

    for key, value in visible:
        ax.axvline(x=value, color='black', linestyle='--', alpha=0.5, label=key, lw=2)
        ax.text(value, label_top, '%s = %s ' % (key, value),
                rotation=90, va='top', ha='right', fontsize='x-large', color='black', alpha=0.5)

    # Draw discontinuities with dotted lines. Barriers are sorted so segments
    # are drawn left to right whatever the field declaration order.
    start = 0
    for barrier in sorted(instance.barriers()):
        below = np.max(np.argwhere(spots < barrier))  # last sample left of the barrier
        above = np.min(np.argwhere(spots > barrier))  # first sample right of the barrier
        assert start <= below
        # Slice ends are exclusive: include `below` so the solid curve reaches
        # the sample the dotted jump starts from.
        plot_main(ax, spots[start:below + 1], payoff_values[start:below + 1])
        plot_main(ax, [barrier, barrier], [payoff_values[below], payoff_values[above]],
                  is_dotted=True)
        start = above
    plot_main(ax, spots[start:], payoff_values[start:])

    ax.set_aspect('equal')
    ax.set_xlim(spots.min(), spots.max())
    ax.tick_params(labelsize='x-large')

    plt.title('Payoff function for a %s option' % type(instance).__name__, size='xx-large')
    ax.set_xlabel('Spot price at maturity', size='x-large')
    ax.set_ylabel('Payoff', size='x-large')

    plt.show()

def show_widgets():
    """Display a product selector plus an interactive form for its fields.

    Choosing a product closes the previous field form and displays a new one
    with one FloatText per declared field, initialised to the field default.
    Editing any input re-plots via plot_payoff(selected_product, 20, ...).
    """
    current_form = None  # interactive form for the currently selected product

    def plot_selected(**kwargs):
        plot_payoff(product_widget.value, 20, **kwargs)

    def product_changed(product):
        nonlocal current_form
        inputs = {f.name: widgets.FloatText(description=f.name, value=f.default)
                  for f in product.fields()}
        if current_form is not None:
            current_form.close()
        new_form = widgets.interactive(plot_selected, **inputs)
        display(new_form)
        current_form = new_form

    product_widget = widgets.Select(options=products, description='Product:')
    display(widgets.interactive(product_changed, product=product_widget))

# Build and display the product selector and field form when this cell runs.
show_widgets()