In [1]:
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from config import data_path
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [2]:
# Read the raw SPY option quotes, then discard every row belonging to the
# two expiries listed below (the reason for excluding them is not recorded here).
options_path = os.path.join(data_path, 'SPY_options.parquet')
df_option = pd.read_parquet(options_path)
for excluded_expiry in ('2015-12-19', '2018-12-23'):
    df_option = df_option[df_option['EXPIRE_DATE'] != excluded_expiry]

In [3]:
# SPY daily bars.
# NOTE(review): read_pickle can execute code embedded in the file; only load trusted data.
df_stock = pd.read_pickle(os.path.join(data_path, 'spy_1d.pkl'))

# Reduce each timestamp to its calendar date so it can be joined on 'ts' later.
df_stock['ts'] = pd.to_datetime(df_stock['ts'].dt.date)

# Close-to-close log return.
close_price = df_stock['close_price']
df_stock['log_return'] = np.log(close_price / close_price.shift(1))

# Realised vol: 21-span exponentially weighted std of log returns, scaled by sqrt(252).
# Computed before the date filter so earlier history feeds the first kept rows.
ewm_return_std = df_stock['log_return'].ewm(span=21, adjust=False).std()
df_stock['rv'] = ewm_return_std * np.sqrt(252)

# Keep 2010-01-01 through 2022-12-31 inclusive.
in_study_window = df_stock['ts'].between('2010-01-01', '2022-12-31')
df_stock = df_stock[in_study_window].copy()

In [4]:
# VIX history reduced to two columns: parsed date ('ts') and the close divided
# by 100 ('vix'), which is later differenced against 'rv'.
df_vix = (
    pd.read_csv(os.path.join(data_path, 'VIX_History.csv'))
    .assign(
        ts=lambda raw: pd.to_datetime(raw['DATE']),
        vix=lambda raw: raw['CLOSE'] / 100,
    )
    .loc[:, ['ts', 'vix']]
)

In [5]:
# GEX series from sqzme.csv, keyed by parsed date ('ts').
df_sq = (
    pd.read_csv(os.path.join(data_path, 'sqzme.csv'))
    .assign(ts=lambda raw: pd.to_datetime(raw['date']))
    .loc[:, ['ts', 'gex']]
)

In [6]:
# Inner-join price, VIX and GEX on 'ts' so only dates present in all three survive.
df_label = (
    df_stock
    .merge(df_vix, on='ts', how='inner')
    .merge(df_sq, on='ts', how='inner')
)
# Volatility Risk Premium (VRP = IV - RV), using the 'vix' column as IV.
df_label['VRP'] = df_label['vix'] - df_label['rv']
df_label.head()

Unnamed: 0,ts,close_price,log_return,rv,vix,gex,VRP
0,2011-05-02,136.220001,-0.00154,0.097546,0.1599,1897313000.0,0.062354
1,2011-05-03,135.729996,-0.003604,0.096693,0.167,1859731000.0,0.070307
2,2011-05-04,134.830002,-0.006653,0.099833,0.1708,1717764000.0,0.070967
3,2011-05-05,133.610001,-0.00909,0.105812,0.182,1361864000.0,0.076188
4,2011-05-06,134.199997,0.004406,0.10307,0.184,1490329000.0,0.08093


In [7]:
def select_atm_option(group, min_dte=2, max_dte=14):
    """Pick one at-the-money option row from a single quote date's options.

    Selection happens in two steps:

    1. Keep only the rows whose ``STRIKE_DISTANCE_PCT`` equals the smallest
       value in ``group`` (the ATM candidates).
    2. Among those, prefer rows with ``min_dte <= DTE <= max_dte`` and return
       the one with the smallest DTE. If none fall inside the window, return
       the candidate whose DTE is closest to either edge of the window.

    Ties are resolved in favour of the first matching row.

    Parameters
    ----------
    group : pandas.DataFrame
        Option rows for one quote date; must contain ``STRIKE_DISTANCE_PCT``
        and ``DTE`` columns. The frame is not modified.
    min_dte, max_dte : int, optional
        Inclusive bounds of the preferred days-to-expiry window
        (defaults 2 and 14).

    Returns
    -------
    pandas.Series
        The selected row, with exactly the columns of ``group``.
    """
    # ATM candidates: every row sharing the smallest strike distance.
    min_distance = group['STRIKE_DISTANCE_PCT'].min()
    atm_options = group[group['STRIKE_DISTANCE_PCT'] == min_distance]

    # Preferred case: shortest-dated candidate inside the DTE window.
    dte = atm_options['DTE']
    valid_dte = atm_options[(dte >= min_dte) & (dte <= max_dte)]
    if len(valid_dte) > 0:
        return valid_dte.loc[valid_dte['DTE'].idxmin()]

    # Fallback: distance from each DTE to the nearer window edge. Held in a
    # standalone Series so no helper column is written onto a filtered slice
    # (which triggers SettingWithCopyWarning) and nothing needs dropping after.
    dte_gap = np.minimum((dte - min_dte).abs(), (dte - max_dte).abs())
    return atm_options.loc[dte_gap.idxmin()]

# Run the ATM selector once per quote date; each group contributes one row.
# NOTE(review): recent pandas versions deprecate passing the grouping column
# into groupby.apply, yet later cells read 'QUOTE_DATE' from this result —
# confirm behaviour on the pandas version in use.
options_by_quote_date = df_option.groupby('QUOTE_DATE', group_keys=False)
df_option_atm = options_by_quote_date.apply(select_atm_option)

In [8]:
# Dates that have both an ATM option pick and a row in df_label.
atm_quote_dates = pd.Index(df_option_atm['QUOTE_DATE'].unique())
date_list = atm_quote_dates.intersection(df_label['ts'])

# Restrict the ATM picks to those shared dates and renumber the rows.
on_shared_date = df_option_atm['QUOTE_DATE'].isin(date_list)
df_option_atm = df_option_atm[on_shared_date].reset_index(drop=True)

In [9]:
# For the ATM contract picked on each date, look up the quote of that same
# contract (identical strike and expiry) on the following date in date_list.
# Matches are gathered in a list and concatenated once: growing a DataFrame
# with pd.concat inside the loop re-copies all accumulated rows every iteration.
next_day_frames = []

for date, next_date in zip(date_list[:-1], date_list[1:]):
    atm_option = df_option_atm[df_option_atm['QUOTE_DATE'] == date]
    strike = atm_option['STRIKE'].values[0]
    expiry = atm_option['EXPIRE_DATE'].values[0]

    next_day_option = df_option[
        (df_option['QUOTE_DATE'] == next_date) &
        (df_option['STRIKE'] == strike) &
        (df_option['EXPIRE_DATE'] == expiry)
    ]

    # A contract with no quote on the next date is simply skipped.
    if len(next_day_option) > 0:
        next_day_frames.append(next_day_option)

if next_day_frames:
    df_next_option = pd.concat(next_day_frames)
else:
    # No matches at all: keep df_option's columns so later column lookups
    # (e.g. 'QUOTE_DATE') still work on the empty frame.
    df_next_option = df_option.iloc[0:0]

df_next_option = df_next_option.reset_index(drop=True)

In [10]:
# Rolling 21-row GEX bands: 5th percentile as the low threshold, 95th as the high.
# The first 20 rows are NaN because the window is not yet full.
gex_rolling = df_label['gex'].rolling(window=21)
df_label['gex_low'] = gex_rolling.quantile(0.05)
df_label['gex_high'] = gex_rolling.quantile(0.95)

# Generate trade signal
def generate_signal(row):
    """Map one row of GEX / VRP readings to a trade signal.

    Returns 1 when ``gex`` is below ``gex_low`` and ``VRP`` is negative,
    -1 when ``gex`` is above ``gex_high`` and ``VRP`` is positive, and 0
    otherwise (including when a threshold is NaN, since NaN comparisons
    are False).
    """
    below_band_negative_vrp = row['gex'] < row['gex_low'] and row['VRP'] < 0
    if below_band_negative_vrp:
        return 1

    above_band_positive_vrp = row['gex'] > row['gex_high'] and row['VRP'] > 0
    if above_band_positive_vrp:
        return -1

    return 0

# Score every row, then lag by one row so the signal stored on a row comes
# from the previous row's data; the first row has no predecessor and gets 0.
same_day_signal = df_label.apply(generate_signal, axis=1)
df_label['signal'] = same_day_signal.shift(1).fillna(0)

In [11]:
# Tally how many rows carry each signal value (generate_signal emits 1, -1 or 0).
df_label['signal'].value_counts()

signal
 0.0    2747
-1.0     169
 1.0       9
Name: count, dtype: int64

In [12]:
# Mark each one-day straddle trade to market and book its PnL on the exit date.
# Longs pay the ask on entry and receive the bid on exit; shorts do the reverse.
for entry_date, exit_date in zip(date_list[:-1], date_list[1:]):
    entry_quote = df_option_atm[df_option_atm['QUOTE_DATE'] == entry_date]
    exit_quote = df_next_option[df_next_option['QUOTE_DATE'] == exit_date]

    # Default: no trade, or no exit quote for the contract, contributes zero.
    enter_pnl, exit_pnl = 0, 0

    if len(exit_quote) > 0:
        flag = df_label.loc[df_label['ts'] == entry_date, 'signal'].values[0]
        if flag == 1:
            # Long straddle: buy call + put at the ask, sell both at the next bid.
            enter_pnl = - 1 * (entry_quote['C_ASK'].values[0] + entry_quote['P_ASK'].values[0])
            exit_pnl = 1 * (exit_quote['C_BID'].values[0] + exit_quote['P_BID'].values[0])
        elif flag == -1:
            # Short straddle: sell call + put at the bid, buy both back at the next ask.
            enter_pnl = 1 * (entry_quote['C_BID'].values[0] + entry_quote['P_BID'].values[0])
            exit_pnl = - 1 * (exit_quote['C_ASK'].values[0] + exit_quote['P_ASK'].values[0])

    df_label.loc[df_label['ts'] == exit_date, 'pnl'] = exit_pnl + enter_pnl

In [13]:
# Work on a copy so the plotting column does not leak back into df_label.
df = df_label[['ts', 'close_price', 'pnl', 'signal']].copy()
df['cum_pnl'] = df['pnl'].cumsum()  # running total of daily PnL

# One figure, two y-axes: strategy PnL on the primary, SPY price on the secondary.
fig = make_subplots(specs=[[{"secondary_y": True}]])

# Cumulative PnL line (primary axis)
pnl_line = go.Scatter(
    x=df['ts'],
    y=df['cum_pnl'],
    name='Strategy PnL',
    line=dict(color='black', width=2),
)
fig.add_trace(pnl_line, secondary_y=False)

# SPY close line (secondary axis)
spy_line = go.Scatter(
    x=df['ts'],
    y=df['close_price'],
    name='SPY',
    line=dict(color='gray', width=2),
)
fig.add_trace(spy_line, secondary_y=True)

# Markers on the PnL curve for rows whose signal is long (1), then short (-1).
marker_specs = (
    (1, 'Long Straddle', 'green'),
    (-1, 'Short Straddle', 'red'),
)
for signal_value, trace_name, marker_color in marker_specs:
    trade_days = df[df['signal'] == signal_value]
    fig.add_trace(
        go.Scatter(
            x=trade_days['ts'],
            y=trade_days['cum_pnl'],
            mode='markers',
            name=trace_name,
            marker=dict(color=marker_color, size=8),
        ),
        secondary_y=False,
    )

# Layout and axis titles
fig.update_layout(
    title="Straddle Strategy vs. SPY",
    xaxis_title="Date",
    legend=dict(x=1.05, y=0.99),
    hovermode="x unified",
    template="plotly_white",
    width=1000,
    height=600,
)
fig.update_yaxes(title_text="Strategy PnL ($)", secondary_y=False)
fig.update_yaxes(title_text="SPY Price", secondary_y=True)

fig.show()