In [None]:
from os.path import join
from itertools import accumulate
from functools import partial
from multiprocessing import Pool
import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings('ignore')
from pca_factors_model import (
    cut_dates, read_residuals,
    get_pca_factor_model)
from utils import ou_fit
import matplotlib.pyplot as plt

In [None]:
# --- Model -----------------------------------------------------------------
N_FACTORS = 10  # passed as n_factors to get_pca_factor_model
DIR_FACTORS_MODELS = './factors_models/'  # where the factor-model parquet files are written

# --- Rolling window ----------------------------------------------------------
# One twelfth of a 252-day year, i.e. 21 days expressed as a Timedelta.
N_YEARS = 1 / 12
window = pd.Timedelta(days=N_YEARS * 252)

# --- Selection and trading ---------------------------------------------------
THRESH_RSQUARED = 0.75  # minimum OU-fit 'score' required by the selection step
LOW, HIGH = 0.5, 1.25   # bands handed to trading_rule as low / high

# Data loading

In [None]:
THRESH = 0.8  # a column is reported when more than this fraction of it is missing
prices = pd.read_parquet('prices_yf.parquet')
# Fraction of missing values per column.
nulls = prices.isna().mean(axis=0)
# Columns above the threshold, sorted by increasing missing fraction.
stocks_w_nans = nulls.loc[nulls.gt(THRESH)].sort_values()
stocks_w_nans

In [None]:
cols = prices.columns
# Every column is kept for now. The disabled alternative would exclude the
# columns listed in stocks_w_nans:
#   stocks = cols[~cols.isin(stocks_w_nans.index)]
stocks = [*cols]
# Keep only the rows where every selected column has a value. (A short forward
# fill, `.ffill(limit=2)`, was tried before dropna() and is currently disabled.)
prices = prices[stocks].dropna()

In [None]:
# Long table with one row per (date, id). Both frames are stacked with
# dropna=False, so after reset_index() their rows line up one-to-one and the
# change column (named 0 by stack) can simply be placed next to the prices.
univ_stacked = pd.concat([
    prices.stack(dropna=False).reset_index(),
    prices.pct_change().stack(dropna=False).reset_index()[0]  # column 0; not clean (dividends, splits, ...)
], axis=1)
univ_stacked.columns = ['date', 'id', 'price', 'chg']  # price is mid_price
CUT = '2007'  # '2006-09-19'
# CUT is referenced as a variable (@CUT) rather than pasted into the query
# text: an unquoted full date such as 2006-10-19 would be parsed as the
# arithmetic expression 2006 - 10 - 19 and silently filter on the wrong date.
univ_stacked = univ_stacked.query('date >= @CUT')
# Back to wide format: columns become (field, id) with fields 'price' and 'chg'.
univ = univ_stacked.pivot(index='date', columns='id')
univ

In [None]:
# Per-stock percentage changes (the 'chg' block of univ), without its first row.
returns = univ['chg'].iloc[1:]
returns

# Residual returns, rolling PCA factor model

In [None]:
%%time
def get_pca_factor_model_rolling(date_end):
    """Fit the PCA factor model on the window ending at `date_end` and save it.

    The model is estimated on the rows of the global ``returns`` (rows with
    missing values removed) labelled from ``date_end - window`` to
    ``date_end``. Each entry of the dict returned by ``get_pca_factor_model``
    is written to ``DIR_FACTORS_MODELS`` as ``<yymmdd>_<name>.parquet``, with
    its index moved into regular columns first.

    Returns None; the files on disk are the only output.
    """
    complete_rows = returns.dropna()
    sample = complete_rows.loc[date_end - window:date_end]
    fitted = get_pca_factor_model(sample, n_factors=N_FACTORS)
    for name, table in fitted.items():
        flat = table.reset_index()
        filename = f"{date_end.strftime('%y%m%d')}_{name}.parquet"
        flat.to_parquet(join(DIR_FACTORS_MODELS, filename))

# Fit one factor model per end date in parallel. The workers write their
# results to disk, so the iterator is only drained (list(...)) to make sure
# every task has completed before the pool is closed.
with Pool() as pool:
    end_dates = cut_dates(returns.index, window)
    list(pool.imap_unordered(get_pca_factor_model_rolling, end_dates))

# Spreads are the cumulative sum of whatever read_residuals returns for the
# model directory (presumably the residuals saved by the step above).
residuals = read_residuals(DIR_FACTORS_MODELS)
spreads = residuals.cumsum()
spreads

# Selection

In [None]:
%%time

def fit_ou_rolling(spreads, window):
    """Apply ``ou_fit`` on rolling windows of every column of `spreads`.

    Parameters
    ----------
    spreads : DataFrame whose columns are fitted independently.
    window : window specification forwarded to ``DataFrame.rolling``; a
        value is produced as soon as a window holds two observations.

    Returns
    -------
    dict with the keys 'theta' and 'score'. Each value is the result of
    ``rolling.apply`` extracting that entry from ``ou_fit``'s result for
    every window (raw NumPy arrays are passed to ``ou_fit``).
    """
    roller = spreads.rolling(window=window, min_periods=2)

    def rolling_entry(name):
        # rolling.apply needs one scalar per window, so ou_fit is evaluated
        # separately for each requested entry.
        return roller.apply(lambda values: ou_fit(values)[name], raw=True)

    return {name: rolling_entry(name) for name in ('theta', 'score')}

# Rolling OU parameters for every spread, then an ad-hoc top-5 ranking.
params = fit_ou_rolling(spreads, window)
# True where the fit score reaches the threshold.
mask_score = params['score'] >= THRESH_RSQUARED
(
    params['theta']
    .rolling(window)
    .mean()
    .where(mask_score)      # ignore dates whose fit score is too low
    .stack()
    .groupby('id')
    .mean()                 # average smoothed theta per stock
    # A stock that never reaches the threshold has no valid value. When
    # stack() keeps NaNs (future_stack=True) its mean is NaN, sort_values()
    # puts it last and tail() would select it, so it is dropped explicitly.
    .dropna()
    .sort_values()
    .tail(5)                # five largest averages
    .index
    .to_list()
)

In [None]:
def selection(params, window, n_stocks, thresh_rsquared):
    """Pick the `n_stocks` ids with the largest average rolling theta.

    Parameters
    ----------
    params : dict with DataFrames under 'theta' and 'score' (same shape,
        columns named 'id'), as returned by ``fit_ou_rolling``.
    window : window forwarded to ``DataFrame.rolling`` to smooth theta.
    n_stocks : maximum number of ids to return.
    thresh_rsquared : minimum score; theta values on dates where the score
        is below it are ignored.

    Returns
    -------
    list of ids sorted by increasing average theta (best last). Ids that
    never reach the score threshold are never returned, so the list can be
    shorter than `n_stocks`.
    """
    mask_score = params['score'] >= thresh_rsquared
    mean_theta = (
        params['theta']
        .rolling(window)
        .mean()
        .where(mask_score)
        .stack()
        .groupby('id')
        .mean()
        # When stack() keeps NaNs (future_stack=True), an id without any valid
        # value gets a NaN mean; sort_values() would place it last and tail()
        # would pick it. Drop such ids before ranking.
        .dropna()
    )
    return mean_theta.sort_values().tail(n_stocks).index.to_list()

# Rebuild the spreads from the stored residuals, refit, and keep three stocks.
spreads = read_residuals(DIR_FACTORS_MODELS).cumsum()
params = fit_ou_rolling(spreads, window)
selected = selection(
    params, window, n_stocks=3, thresh_rsquared=THRESH_RSQUARED)
selected

In [None]:
def fit(spreads, window, n_stocks, thresh_rsquared):
    """Select stocks on `spreads` and summarise the selected spreads.

    Runs ``fit_ou_rolling`` followed by ``selection`` with the given
    arguments.

    Returns
    -------
    tuple ``(selected, spreads, descs)``: the selected ids, the columns of
    `spreads` for those ids, and ``describe()`` of these columns.
    """
    ou_params = fit_ou_rolling(spreads, window)
    top = selection(ou_params, window, n_stocks, thresh_rsquared)
    chosen = spreads[top]
    return top, chosen, chosen.describe()

# Show only the summary statistics (last element of the returned tuple).
fit(spreads, window, 3, THRESH_RSQUARED)[-1]

# Trading period

In [None]:
def trading_rule(cur_pos_spread, st_spread, **kwargs):
    """Return the next position given the current one and the spread value.

    Requires the keyword arguments ``low`` and ``high`` with ``high > low``
    (checked with an assert).

    * ``st_spread < -high``  -> +1
    * ``st_spread > +high``  -> -1
    * ``low <= |st_spread| <= high`` and the current position has the sign
      opposite to the spread -> the current position is kept
    * anything else -> 0
    """
    assert 'low' in kwargs and 'high' in kwargs and kwargs['high'] > kwargs['low']
    low, high = kwargs['low'], kwargs['high']

    # Beyond the outer band: take the position against the spread.
    if st_spread < -high:
        return +1
    if st_spread > +high:
        return -1

    # Between the bands (|st_spread| in [low, high]): keep a position only if
    # it is already the one opposite to the sign of the spread.
    if low <= np.abs(st_spread) <= high:
        if np.sign(st_spread) * cur_pos_spread == -1:
            return cur_pos_spread

    # Inside the inner band, or holding nothing suitable: flat.
    return 0

# Try the rule on a synthetic sine wave.
xxx = np.linspace(0, 3, 100)
yyy = 2 * np.sin(3 * xxx)

kwargs = {'low': LOW, 'high': HIGH}
rule = partial(trading_rule, **kwargs)
# accumulate calls rule(previous position, current value), starting flat (0).
# It yields len(yyy) + 1 positions; the last one is dropped to match xxx.
pos_spread = [*accumulate(yyy, rule, initial=0)][:-1]
plt.plot(xxx, pos_spread)
plt.plot(xxx, yyy)
plt.grid(True)

In [None]:
def trade_series(st_spread, trade_rule, **kwargs):
    """Run `trade_rule` sequentially over a spread series.

    Parameters
    ----------
    st_spread : Series of spread values.
    trade_rule : callable ``(current_position, value, **kwargs) -> position``.
    **kwargs : forwarded to `trade_rule` on every call.

    Returns
    -------
    Series on the index of `st_spread`. The first position is 0 and the
    position at each later label is the rule's output for the previous
    value; the position computed from the last value is discarded.
    """
    step = partial(trade_rule, **kwargs)
    path = list(accumulate(st_spread.values, step, initial=0))
    # The initial state adds one element; drop the trailing one to realign.
    path.pop()
    return pd.Series(path, index=st_spread.index)

# Same synthetic sine wave as before, this time going through trade_series.
xxx = np.linspace(0, 3, 100)
yyy = pd.Series(2 * np.sin(3 * xxx))

kwargs = {'low': LOW, 'high': HIGH}
pos = trade_series(yyy, trade_rule=trading_rule, **kwargs)
plt.plot(xxx, pos)
plt.plot(xxx, yyy)
plt.grid(True)

In [None]:
# Positions generated on the first spread column. The spread is passed to the
# rule as is (it is not standardised in this cell).
spread = spreads.iloc[:, 0]
pos = trade_series(spread, trade_rule=trading_rule, low=LOW, high=HIGH)
spread.plot()
# Positions are scaled by the spread's maximum so both curves share one axis.
(pos * spread.max()).plot(grid=True)

In [None]:
kwargs = {'low': LOW, 'high': HIGH}
# NOTE(review): `spreads` is rebound here to the selected columns only, so
# every later cell that reads `spreads` works on this subset, and re-running
# this cell selects among the already selected columns.
selected, spreads, descs = fit(
    spreads, window, n_stocks=3, thresh_rsquared=THRESH_RSQUARED)
# One position series per selected spread.
run_rule = partial(trade_series, trade_rule=trading_rule, **kwargs)
spreads.apply(run_rule)

## On real data

In [None]:
CUT = '2022-05-01'
# In-sample: every row labelled up to and including CUT.
spreads_in = spreads.loc[:CUT]
# Out-of-sample: rows labelled from CUT onwards. Label slices include both end
# points, so rows dated exactly CUT would appear in both samples; they are
# kept in-sample only to keep the two sets disjoint.
spreads_out = spreads.loc[CUT:]
spreads_out = spreads_out.loc[~spreads_out.index.isin(spreads_in.index)]

In [None]:
def trade(spreads_out, top_stocks, descs, trade_rule, **kwargs):
    """Generate positions on `spreads_out` for the stocks in `top_stocks`.

    Each selected spread is standardised with the 'mean' and 'std' rows of
    `descs` (matched by column label), then ``trade_series`` is run on every
    standardised column with `trade_rule` and `kwargs`.

    Returns the DataFrame of positions, one column per selected stock.
    """
    chosen = spreads_out[top_stocks]
    center = descs.loc['mean', :]
    scale = descs.loc['std', :]
    st_spreads = chosen.sub(center).div(scale)
    run_rule = partial(trade_series, trade_rule=trade_rule, **kwargs)
    return st_spreads.apply(run_rule)

# Select and summarise on the in-sample spreads, then trade out-of-sample.
# NOTE(review): `spreads` is rebound to the in-sample selected columns here;
# later cells reading `spreads` see this reduced frame.
top_stocks, spreads, descs = fit(
    spreads_in, window, n_stocks=3, thresh_rsquared=THRESH_RSQUARED)
pos_spreads = trade(
    spreads_out, top_stocks, descs, trade_rule=trading_rule, low=LOW, high=HIGH)

In [None]:
# Selected spreads returned by the fit() call above.
spreads.plot()

In [None]:
# Positions produced by trade() on the out-of-sample spreads.
pos_spreads.plot()

In [None]:
def fit_n_trade(spreads_in, spreads_out, **kwargs):
    """Fit on `spreads_in`, then trade `spreads_out` with ``trading_rule``.

    Required keyword arguments: ``window``, ``n_stocks`` and
    ``thresh_rsquared`` (forwarded to ``fit``), ``low`` and ``high``
    (forwarded to the trading rule). A missing key raises KeyError.

    Returns the positions DataFrame produced by ``trade``.
    """
    # Only the selected ids and their summary statistics are needed here.
    top_stocks, _, descs = fit(
        spreads_in,
        kwargs['window'],
        kwargs['n_stocks'],
        kwargs['thresh_rsquared'],
    )
    return trade(
        spreads_out,
        top_stocks,
        descs,
        trading_rule,
        low=kwargs['low'],
        high=kwargs['high'],
    )

In [None]:
%%timeit
# Parameters for one fit-and-trade pass; the result is discarded, only the
# run time matters in this cell.
kwargs = dict(
    window=window,
    n_stocks=3,
    thresh_rsquared=THRESH_RSQUARED,
    low=LOW,
    high=HIGH,
)
_ = fit_n_trade(spreads_in, spreads_out, **kwargs)

# 253 ms ± 33 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# Rolling

In [None]:
%%time

kwargs = dict(
    window=window,
    n_stocks=3,
    thresh_rsquared=THRESH_RSQUARED,
    low=LOW,
    high=HIGH,
)

gcd = '1W'  # not referenced by the code visible in this notebook

# One ((in_start, in_end), (out_start, out_end)) pair per end date: the
# in-sample period is the window ending at end_date, the out-of-sample period
# is the single day that follows it.
splits = []
for end_date in cut_dates(spreads.index, window):
    next_day = end_date + pd.Timedelta(days=1)
    splits.append(((end_date - window, end_date), (next_day, next_day)))

def fit_n_trade_split(split):
    """Run ``fit_n_trade`` for one ((in_start, in_end), (out_start, out_end)) split.

    Both periods are sliced by label from the global ``spreads``; the global
    ``kwargs`` supplies the fit and trading parameters.
    """
    (in_start, in_end), (out_start, out_end) = split
    in_sample = spreads.loc[in_start:in_end]
    out_sample = spreads.loc[out_start:out_end]
    return fit_n_trade(in_sample, out_sample, **kwargs)

# Process the splits in parallel. Results arrive in arbitrary order, hence the
# sort_index() once they are stacked; `positions` is None if nothing came back.
with Pool() as pool:
    pos = [
        p for p in pool.imap_unordered(fit_n_trade_split, splits)
        if p is not None
    ]
    positions = pd.concat(pos).sort_index() if pos else None

# CPU times: user 328 ms, sys: 39.2 ms, total: 367 ms
# Wall time: 6.54 s

In [None]:
# The rolling run sets positions to None when no split produced a result.
assert positions is not None

In [None]:
# Number of non-zero positions per row.
positions[positions.abs().gt(0)].count(axis=1).plot(grid=True)

In [None]:
# Net position per row (signed sum across columns).
positions.sum(axis=1).plot(grid=True)

In [None]:
# Gross position per row (sum of absolute values across columns).
positions.abs().sum(axis=1).plot(grid=True)

In [None]:
# Scale every row so that its absolute positions sum to one. NaNs (for
# instance 0 / 0 on rows without any position) are replaced by 0.
gross_exposure = positions.abs().sum(axis=1)
positions = positions.div(gross_exposure, axis=0).fillna(0)
positions.abs().sum(axis=1).plot(grid=True)

In [None]:
# Net position per row after the normalisation above.
positions.sum(axis=1).plot(grid=True)

In [None]:
# Sum across columns of the absolute row-to-row change in positions.
positions.diff().abs().sum(axis=1).plot(grid=True)