In [150]:
import numpy as np
import pandas as pd
from abc import ABC, abstractmethod

class Forecast(ABC):
    def __init__(self, price_date: pd.DataFrame, lookback: int, horizon: int):
        self.price_data = price_date
        self.lookback = lookback
        self.horizon = horizon

    @abstractmethod
    def update(self):
        pass


class ABM(Forecast):
    """Arithmatic Brownian Motion"""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
    
    def update(self, t, tickers):
        price_data = self.price_data[tickers]
        ret_data = self.price_data[tickers].pct_change().dropna()
        rng = np.random.default_rng()
        sample_size = 10

        idx = ret_data.index.get_indexer([t], method="pad")[0]
        end_dt = idx
        start_dt = max(idx - self.lookback, 0)
        assert start_dt >= 0, "start_dt must be greater than 0"

        price_observations = ret_data.iloc[start_dt:end_dt]
        # numpy objects
        mu = np.mean(price_observations, axis=0)
        cov = np.cov(price_observations, rowvar=False) * 1e2

        periods = ret_data.iloc[idx:idx+self.horizon].index

        # create planning matrix
        mvn = rng.multivariate_normal(mu, cov, size=(sample_size, self.horizon))
        self.mvn_avg = pd.DataFrame(index=periods, data=mvn.mean(axis=0), columns=ret_data.columns)

        print((self.mvn_avg + 1).cumprod())
        self.price_est = price_data.loc[t].multiply((self.mvn_avg + 1).cumprod())
        self.price_est

        # for i, tau in enumerate(periods):
        #     # cumulative return for each rebal period (previous tau to current tau)
        #     if i == 0:
        #         self.ret_est[(periods[0], tau)] = ((self.mvn_avg.loc[periods[0] : tau] + 1).cumprod() - 1).iloc[-1]
        #     else:
        #         self.ret_est[(periods[0], tau)] = ((self.mvn_avg.loc[periods[i - 1] : tau] + 1).cumprod() - 1).iloc[-1]
    
class GBM(Forecast):
    """Geometric Brownian Motion"""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def update(self, t, tickers):
        price_data = self.price_data[tickers]
        rng = np.random.default_rng()
        self.ret_est = {}

        idx = price_data.index.get_indexer([t], method="pad")[0]
        end_dt = idx
        start_dt = max(idx - self.lookback, 0)
        assert start_dt >= 0, "start_dt must be greater than 0"

        price_observations = price_data.iloc[start_dt:end_dt]
        # numpy objects
        mu = np.mean(price_observations, axis=0).to_numpy().reshape(-1, 1)
        # cov = np.cov(self.ret_obs, rowvar=False)
        sigma = np.std(price_observations, axis=0).to_numpy().reshape(-1, 1)
        # Parameters
        # time in years
        T = 1
        #number of steps
        n = 252
        # initial stock prices
        S0 = price_data.iloc[idx].to_numpy().reshape(-1, 1)
        # number of time steps
        dt = T/n
        # simulation using numpy arrays
        print(rng.normal(0, np.sqrt(dt), size=(self.horizon, len(tickers))).T.shape)
        print((mu - sigma ** 2 / 2) * dt
            + sigma * rng.normal(0, np.sqrt(dt), size=(self.horizon, len(tickers))).T)
        print(((mu - sigma ** 2 / 2) * dt).shape)
        St = np.exp(
            (mu - sigma ** 2 / 2) * dt
            + sigma * rng.normal(0, np.sqrt(dt), size=(self.horizon, len(tickers))).T
        )
        print(St.shape)
        return pd.DataFrame((S0 * St.cumprod(axis=0)).T, index=price_data.iloc[idx:idx+self.horizon].index, columns=price_data.columns)

In [151]:
df = pd.read_parquet("../raw_data/spx_stock_prices.parquet")
df

Unnamed: 0_level_0,JAVA_10078,ORCL_10104,MSFT_10107,SDS_10108,AYE_10137,TROW_10138,HON_10145,EMC_10147,BEAM_10225,LLTC_10299,...,CFN_92988,AVGO_93002,VRSK_93089,DG_93096,FTNT_93132,VAL_93159,GNRC_93246,QEP_93422,CBOE_93429,TSLA_93436
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1999-01-04,90.0625,43.0000,141.000,0.0,0.0,0.00,43.5625,87.0000,31.2500,0.0,...,0.0,0.00,0.00,0.00,0.00,0.0,0.00,0.0,0.00,0.00
1999-01-05,92.5625,44.3125,146.500,0.0,0.0,0.00,43.8125,91.1250,31.6250,0.0,...,0.0,0.00,0.00,0.00,0.00,0.0,0.00,0.0,0.00,0.00
1999-01-06,90.9375,46.3750,151.250,0.0,0.0,0.00,44.6875,89.5000,32.2500,0.0,...,0.0,0.00,0.00,0.00,0.00,0.0,0.00,0.0,0.00,0.00
1999-01-07,89.6875,45.6250,150.500,0.0,0.0,0.00,44.2500,93.7500,31.0625,0.0,...,0.0,0.00,0.00,0.00,0.00,0.0,0.00,0.0,0.00,0.00
1999-01-08,90.8750,46.2500,149.875,0.0,0.0,0.00,44.1875,93.5625,31.1250,0.0,...,0.0,0.00,0.00,0.00,0.00,0.0,0.00,0.0,0.00,0.00
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2022-03-25,0.0000,81.7300,303.680,0.0,0.0,148.92,197.7900,0.0000,0.0000,0.0,...,0.0,628.87,208.00,221.47,332.73,0.0,312.61,0.0,114.73,1010.64
2022-03-28,0.0000,83.6000,310.700,0.0,0.0,150.27,197.1100,0.0000,0.0000,0.0,...,0.0,632.88,211.60,226.30,338.53,0.0,312.33,0.0,112.94,1091.84
2022-03-29,0.0000,84.2900,315.410,0.0,0.0,155.00,197.5400,0.0000,0.0000,0.0,...,0.0,641.47,214.80,228.52,347.48,0.0,321.14,0.0,114.78,1099.57
2022-03-30,0.0000,83.3600,313.860,0.0,0.0,153.10,196.5400,0.0000,0.0000,0.0,...,0.0,631.09,216.97,227.47,341.93,0.0,311.06,0.0,114.69,1093.99


In [152]:
abm = ABM(price_date=df, lookback=252, horizon=90)

In [153]:
sub_tic = np.random.default_rng().choice(df.query("index >= '2019-01-01'").replace(0.0, np.nan).dropna(axis=1, how="any").columns, size=10, replace=False)
sub_tic


array(['SLB_14277', 'TROW_10138', 'GE_12060', 'KIM_77129', 'AVB_80381',
       'MTD_85621', 'PNC_60442', 'INTC_59328', 'DE_19350', 'APA_39490'],
      dtype=object)

In [154]:
pred = abm.update(t="2019-01-02", tickers=sub_tic)

            SLB_14277  TROW_10138  GE_12060  KIM_77129  AVB_80381  MTD_85621  \
date                                                                           
2019-01-02   0.934055    0.916665  0.943353   1.071430   1.018975   0.926160   
2019-01-03   0.876850    0.907217  0.804271   1.172367   0.992799   0.936696   
2019-01-04   0.803336    0.841768  0.706168   1.088977   0.970451   0.862969   
2019-01-07   0.797417    0.831746  0.652362   1.119817   1.032590   0.905665   
2019-01-08   0.814684    0.838728  0.605742   1.067083   1.019195   0.947445   
...               ...         ...       ...        ...        ...        ...   
2019-05-06   0.573054    0.578147  0.112583   0.673806   0.967495   0.386653   
2019-05-07   0.586699    0.595363  0.109189   0.699903   0.962882   0.377127   
2019-05-08   0.583393    0.656428  0.110009   0.731656   0.981053   0.401567   
2019-05-09   0.502195    0.590985  0.097897   0.715152   0.931681   0.361509   
2019-05-10   0.500738    0.625074  0.105

In [155]:
pred

Unnamed: 0_level_0,SLB_14277,TROW_10138,GE_12060,KIM_77129,AVB_80381,MTD_85621,PNC_60442,INTC_59328,DE_19350,APA_39490
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
2019-01-02,34.746832,84.186514,7.593990,15.621454,172.410497,505.720146,111.653695,48.667215,138.470788,21.878760
2019-01-03,32.618826,83.318777,6.474384,17.093111,167.981634,511.473359,115.669731,46.029298,128.167425,19.164526
2019-01-04,29.884112,77.307994,5.684650,15.877291,164.200363,471.215637,103.455804,37.938907,109.760114,15.682040
2019-01-07,29.663896,76.387597,5.251510,16.326926,174.714174,494.529440,107.780092,39.879096,119.741352,15.157371
2019-01-08,30.306241,77.028799,4.876227,15.558070,172.447821,517.342725,109.048847,39.039005,121.319377,16.443557
...,...,...,...,...,...,...,...,...,...,...
2019-05-06,21.317603,53.096988,0.906293,9.824098,163.700148,211.128105,65.481044,9.621033,24.671440,7.861833
2019-05-07,21.825218,54.678093,0.878971,10.204585,162.919580,205.926696,69.683102,9.775019,25.253364,7.670133
2019-05-08,21.702202,60.286332,0.885576,10.667546,165.994214,219.271888,74.610869,11.345208,28.289884,7.713309
2019-05-09,18.681649,54.276098,0.788071,10.426921,157.640399,197.398488,65.781334,9.110791,23.072441,6.240293


In [156]:
import plotly.express as px
px.line(pred)