In [1]:
import pandas as pd
import numpy as np
from sklearn.mixture import GaussianMixture
from rbergomi import rb_pricing
from joblib import Parallel, delayed
import pickle
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# !wget https://github.com/qmfin/option_data/raw/main/sp500_2017.json.bz2

In [3]:
# Load the bz2-compressed JSON dump (presumably 2017 S&P 500 option quotes,
# going by the file name). orient='index' makes each top-level key one row.
df = pd.read_json(
    './sp500_2017.json.bz2',
    orient='index',
    compression='bz2',
)

In [4]:
# df

In [5]:
# Keep only the rows whose is_call flag equals 1 (the call contracts).
call_options = df.query(
    expr='is_call==1',
)

In [6]:
# Raw NumPy view of the call quotes with columns ordered as
# (forward_price, tau, strike_price).
call_m_t = call_options.loc[:, ['forward_price','tau','strike_price']].values

In [7]:
# Two-column feature matrix for calls:
#   column 0 = strike_price / forward_price
#   column 1 = tau
# Built directly from call_options (rather than by reading and overwriting
# call_m_t) so that re-running this cell on its own gives the same result
# instead of indexing into an already-reduced 2-column array.
call_m_t = np.vstack([
    call_options['strike_price'].values / call_options['forward_price'].values,
    call_options['tau'].values,
]).T

In [8]:
# call_m_t

In [9]:
# plt.scatter(*call_m_t[0::100].T, marker='.')

In [10]:
# Fit a 64-component Gaussian mixture to the call feature matrix.
# NOTE(review): no random_state is passed, so the fit differs between runs.
# Passing an int seed here would presumably also make every later .sample()
# call return identical points -- verify before seeding.
call_gmm = GaussianMixture(n_components=64)
call_gmm.fit(call_m_t)

In [11]:
# plt.scatter(*call_gmm.sample(10000)[0].T, marker='.')

In [12]:
# Keep only the rows whose is_call flag equals -1 (the put contracts).
put_options = df.query(
    expr='is_call==-1',
)

In [13]:
# Raw NumPy view of the put quotes with columns ordered as
# (forward_price, tau, strike_price).
put_m_t = put_options.loc[:, ['forward_price','tau','strike_price']].values

In [14]:
# Two-column feature matrix for puts:
#   column 0 = strike_price / forward_price
#   column 1 = tau
# Built directly from put_options (rather than by reading and overwriting
# put_m_t) so that re-running this cell on its own gives the same result
# instead of indexing into an already-reduced 2-column array.
put_m_t = np.vstack([
    put_options['strike_price'].values / put_options['forward_price'].values,
    put_options['tau'].values,
]).T

In [15]:
# put_m_t

In [16]:
# plt.scatter(*put_m_t[0::100].T, marker='.')

In [17]:
# Fit a 64-component Gaussian mixture to the put feature matrix.
# NOTE(review): no random_state is passed, so the fit differs between runs.
# Passing an int seed here would presumably also make every later .sample()
# call return identical points -- verify before seeding.
put_gmm = GaussianMixture(n_components=64)
put_gmm.fit(put_m_t)

In [18]:
# plt.scatter(*put_gmm.sample(10000)[0].T, marker='.')

In [19]:
def gen_training_data(call_gmm, put_gmm, n_call=512, n_put=512):
    """Draw one random parameter set and price a sampled batch of options.

    Intended parameter domain:
    (v0, eta, rho, H) in (0, 1] x (0, 5] x (-1, 1) x (0, 0.5).
    The draws below are scaled ``np.random.rand()`` values, so the sampled
    ranges are actually half-open: H in [0, 0.5), rho in [-1, 1),
    eta in [0, 5), v0 in [0, 1).

    Parameters
    ----------
    call_gmm, put_gmm : objects exposing ``sample(n)``
        Only element ``[0]`` of the returned tuple (the sampled points) is
        used. In this notebook these are mixtures fitted on two-column
        (strike/forward, tau) arrays.
    n_call, n_put : int
        Number of call / put points to sample.

    Returns
    -------
    Whatever ``rb_pricing`` returns for the stacked inputs.

    Notes
    -----
    NOTE(review): Gaussian-mixture samples are unbounded, so a sampled
    strike/forward ratio or tau may be non-positive -- confirm that
    ``rb_pricing`` tolerates such rows.
    """
    # Keep this draw order: with a seeded global NumPy RNG it determines
    # which uniform variate ends up in which parameter.
    H = 0.5 * np.random.rand()
    rho = 2 * np.random.rand() - 1
    eta = 5 * np.random.rand()
    v0 = np.random.rand()

    call_points, _ = call_gmm.sample(n_call)
    put_points, _ = put_gmm.sample(n_put)

    # Append the option-type flag as a last column: +1 for calls, -1 for puts.
    call_rows = np.column_stack([call_points, np.ones(n_call)])
    put_rows = np.column_stack([put_points, -np.ones(n_put)])
    pricing_inputs = np.concatenate([call_rows, put_rows], axis=0)

    spot = 1.0  # passed to rb_pricing as S0
    return rb_pricing(pricing_inputs, spot, H, rho, eta, v0)

In [20]:
# Batch-generation settings.
n_thread = 8             # number of joblib workers
n_sample = 32            # gen_training_data calls (parameter draws) per file
n_files = 32 * 32 * 32   # number of files a full run would write
for file_idx in range(n_files):
    # Each task draws its own parameter set and prices 512 calls + 512 puts.
    results = Parallel(n_jobs=n_thread, verbose=1)(
        delayed(gen_training_data)(call_gmm, put_gmm, 512, 512)
        for _ in range(n_sample)
    )
    # Use a context manager so the file handle is flushed and closed even if
    # pickling raises; the bare open() used previously was never closed.
    with open('trn_%09d.pkl' % file_idx, 'wb') as out_file:
        pickle.dump(results, out_file)
    # NOTE(review): this unconditional break stops after the first file
    # (trn_000000000.pkl). Remove it to generate all n_files batches.
    break

[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  32 out of  32 | elapsed:   23.3s finished
