# Stock Embedding 

This project notebook is based on the paper [_Learning Embedded Representation of the Stock Correlation Matrix using Graph Machine Learning_](https://arxiv.org/pdf/2207.07183.pdf).

In [1]:
# setup environment
from stock_embedding import StockEmbedding

In [2]:
# fetch stock prices from yfinance for the given start and end dates and prepare the return data
# GICS-related info is fetched directly from Wikipedia
se = StockEmbedding(start_date='2021-01-01', end_date='2022-08-04')

[*********************100%***********************]  20 of 20 completed

  return csr_matrix(np.triu(np.sqrt(2 * (1 - self.corr_mat))))
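
A minimal sketch of the return and distance preparation in plain Python. It assumes daily log returns and Pearson correlation (the library may use simple returns or another estimator); `log_returns`, `pearson` and `corr_distance` are hypothetical helpers, and the prices are made-up toy data. The mapping d = sqrt(2 * (1 - rho)) is the one in the `np.sqrt(2 * (1 - self.corr_mat))` line above: perfectly correlated stocks get distance 0, perfectly anti-correlated stocks get distance 2.

```python
import math

def log_returns(prices):
    """Daily log returns from a list of closing prices (hypothetical helper)."""
    return [math.log(b / a) for a, b in zip(prices, prices[1:])]

def pearson(x, y):
    """Pearson correlation of two equal-length return series."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    vx = sum((a - mx) ** 2 for a in x)
    vy = sum((b - my) ** 2 for b in y)
    return cov / math.sqrt(vx * vy)

def corr_distance(rho):
    """Map a correlation in [-1, 1] to a distance in [0, 2]: d = sqrt(2 * (1 - rho))."""
    # Clamp at 0: floating-point error can push rho slightly above 1,
    # and math.sqrt raises ValueError on negative input.
    return math.sqrt(max(0.0, 2.0 * (1.0 - rho)))

# Toy closing prices for three hypothetical tickers
prices = {
    "AAA": [100, 101, 103, 102, 105],
    "BBB": [50, 50.6, 51.5, 51.1, 52.4],
    "CCC": [20, 19.5, 19.9, 20.6, 20.1],
}
rets = {t: log_returns(p) for t, p in prices.items()}
tickers = sorted(rets)

# Keep only the upper triangle (i < j), analogous to np.triu in the library
dist = {
    (a, b): corr_distance(pearson(rets[a], rets[b]))
    for i, a in enumerate(tickers)
    for b in tickers[i + 1:]
}
for pair, d in dist.items():
    print(pair, round(d, 3))
```

Because the distance matrix is symmetric with a zero diagonal, storing only the upper triangle loses nothing.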


In [3]:
se.save_return_df()

## Hyperparameter Tuning Based on V-Measure
We tune the following parameters:
* r (int): number of random walks started from each node in the network
* l (int): length of each random walk
* p (float): return parameter; the walk steps back to the node it just visited with unnormalized weight 1/p, so a larger p makes backtracking less likely
* q (float): in-out parameter; the walk moves to nodes farther from the previous node with unnormalized weight 1/q, so q < 1 encourages exploring unvisited parts of the graph
* vector_size (int): embedding size (dim in the paper)
* window (int): maximum distance between the current and predicted word within a sentence, i.e. between nodes within a walk (w in the paper)
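
The r, l, p and q parameters match those of node2vec-style biased random walks. The sketch below shows how p and q reweight each step relative to the previous node instead of acting as probabilities on their own. It assumes an undirected weighted graph stored as a dict of dicts; that representation, the toy graph and the function names are hypothetical, and the library builds its graph from the correlation matrix and may walk it differently.

```python
import random

def biased_weight(graph, prev, cur, nxt, p, q):
    """Unnormalized node2vec weight for stepping cur -> nxt when the walk came from prev."""
    w = graph[cur][nxt]          # edge weight of cur -> nxt
    if nxt == prev:              # return to the previous node
        return w / p
    if nxt in graph[prev]:       # stays at distance 1 from prev (local, BFS-like move)
        return w
    return w / q                 # moves farther away from prev (outward, DFS-like move)

def node2vec_walk(graph, start, l, p, q, rng):
    """One walk of l nodes starting at `start`."""
    walk = [start]
    while len(walk) < l:
        cur = walk[-1]
        nbrs = list(graph[cur])
        if not nbrs:
            break
        if len(walk) == 1:
            # First step has no previous node: sample by edge weight only
            weights = [graph[cur][x] for x in nbrs]
        else:
            prev = walk[-2]
            weights = [biased_weight(graph, prev, cur, x, p, q) for x in nbrs]
        walk.append(rng.choices(nbrs, weights=weights, k=1)[0])
    return walk

def generate_walks(graph, r, l, p, q, seed=0):
    """r walks of length l from every node (the r and l parameters above)."""
    rng = random.Random(seed)
    walks = []
    for _ in range(r):
        nodes = list(graph)
        rng.shuffle(nodes)
        for node in nodes:
            walks.append(node2vec_walk(graph, node, l, p, q, rng))
    return walks

# Toy undirected graph: three banks and two airlines joined by one weak edge
graph = {
    "JPM": {"GS": 1.0, "MS": 1.0},
    "GS": {"JPM": 1.0, "MS": 1.0},
    "MS": {"JPM": 1.0, "GS": 1.0, "DAL": 0.1},
    "DAL": {"MS": 0.1, "UAL": 1.0},
    "UAL": {"DAL": 1.0},
}
walks = generate_walks(graph, r=2, l=5, p=2, q=0.5)
print(len(walks), walks[0])
```

With p = 2 and q = 0.5 (the setting of two of the three candidates below), backtracking is down-weighted and outward moves are up-weighted.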

In [4]:
param_lst = [{'l': 50, 'r': 10, 'p': 0.5, 'q': 2, 'w': 5, 'dim': 16},
             {'l': 100, 'r': 50, 'p': 2, 'q': 0.5, 'w': 5, 'dim': 16},
             {'l': 200, 'r': 10, 'p': 2, 'q': 0.5, 'w': 5, 'dim': 32}]

summary, opt_param = se.hyperparam_tuning(param_lst)

In [5]:
# similar to Table 3 in the paper, with an additional GICS industry column
summary

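|   | r | l | p | q | dim | w | sector | group | industry | subindustry | average |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 10 | 50 | 0.5 | 2.0 | 16 | 5 | 0.403711 | 0.607894 | 0.717576 | 0.816442 | 0.636406 |
| 1 | 50 | 100 | 2.0 | 0.5 | 16 | 5 | 0.463947 | 0.632105 | 0.731896 | 0.82062 | 0.662142 |
| 2 | 10 | 200 | 2.0 | 0.5 | 32 | 5 | 0.483616 | 0.622592 | 0.730988 | 0.821894 | 0.664773 |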
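
Each score column in the summary is read here as the V-measure between a clustering of the embeddings and the GICS labels at that level, with `average` their mean (the numbers are consistent with this: row 0's four scores average to 0.636406). How the library clusters the embeddings is not shown. Below is a minimal sketch of V-measure, the harmonic mean of homogeneity and completeness, followed by picking the parameter set with the best average. Selecting by `average` is an assumption that matches the `opt_param` printed in the next cell; the helper names are hypothetical.

```python
import math
from collections import Counter

def entropy(labels):
    """Shannon entropy H(X) of a list of labels."""
    n = len(labels)
    return -sum((c / n) * math.log(c / n) for c in Counter(labels).values())

def conditional_entropy(a, b):
    """H(A | B) for two parallel label lists."""
    n = len(a)
    joint = Counter(zip(a, b))
    b_counts = Counter(b)
    return -sum((n_ab / n) * math.log(n_ab / b_counts[y]) for (_, y), n_ab in joint.items())

def v_measure(classes, clusters):
    """Harmonic mean of homogeneity and completeness (beta = 1)."""
    h_c, h_k = entropy(classes), entropy(clusters)
    homogeneity = 1.0 if h_c == 0 else 1.0 - conditional_entropy(classes, clusters) / h_c
    completeness = 1.0 if h_k == 0 else 1.0 - conditional_entropy(clusters, classes) / h_k
    if homogeneity + completeness == 0:
        return 0.0
    return 2 * homogeneity * completeness / (homogeneity + completeness)

# Toy check: GICS sectors vs. hypothetical cluster ids for six stocks
sectors = ["Financials"] * 3 + ["Industrials"] * 3
print(v_measure(sectors, [0, 0, 0, 1, 1, 1]))  # perfect clustering
print(v_measure(sectors, [0, 0, 0, 0, 0, 0]))  # everything in one cluster

# Per-level scores (sector, group, industry, subindustry) copied from the summary table
rows = [
    ({"r": 10, "l": 50, "p": 0.5, "q": 2.0, "dim": 16, "w": 5},
     [0.403711, 0.607894, 0.717576, 0.816442]),
    ({"r": 50, "l": 100, "p": 2.0, "q": 0.5, "dim": 16, "w": 5},
     [0.463947, 0.632105, 0.731896, 0.82062]),
    ({"r": 10, "l": 200, "p": 2.0, "q": 0.5, "dim": 32, "w": 5},
     [0.483616, 0.622592, 0.730988, 0.821894]),
]
best_params, best_scores = max(rows, key=lambda row: sum(row[1]) / len(row[1]))
print(best_params, round(sum(best_scores) / len(best_scores), 4))
```

V-measure does not depend on how clusters are numbered, so cluster ids never need to be matched to GICS names.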


In [13]:
# optimal hyperparameters from the tuning step
print(opt_param)
se.learn_embedding(r=opt_param['r'],
                   l=opt_param['l'],
                   p=opt_param['p'],
                   q=opt_param['q'],
                   vector_size=opt_param['dim'],
                   window=opt_param['w'])

{'l': 200, 'r': 10, 'p': 2, 'q': 0.5, 'w': 5, 'dim': 32}


{'A': array([-0.7238251 , -0.16119963, -0.741202  , -0.42152733,  0.7228185 ,
         0.44488317,  1.0018129 ,  0.27913818,  0.7947404 , -0.7703384 ,
         1.1945633 , -0.04896798, -0.8384862 , -0.8193982 , -1.0136195 ,
        -0.03663795,  0.38047937, -0.81301117,  0.33357424,  0.36251083,
        -1.0460135 ,  0.48991445, -0.87486166, -0.12089798, -1.2134959 ,
         0.27891377,  0.7432166 ,  0.26325884, -0.54440975,  0.3908448 ,
        -0.0371459 ,  0.6995408 ], dtype=float32),
 'AAL': array([ 0.06458579,  0.5156098 , -0.8652317 , -0.5805373 ,  0.00410277,
         1.2399111 ,  1.0461245 , -1.5775771 , -0.25471032,  0.1920704 ,
         0.57260114, -0.2805938 , -0.11760406,  0.06924746, -0.5129515 ,
         0.5372906 ,  0.2505066 , -1.1261193 ,  1.123684  ,  0.36512932,
        -0.5923059 ,  1.1733475 ,  0.46941003, -0.26787785, -0.42738885,
        -0.1754152 , -0.815444  , -0.292768  , -0.27769288,  0.09555124,
         0.461091  ,  0.9067853 ], dtype=float32),
 'AAP': ar
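
The `window` parameter comes from word2vec, which treats each walk as a sentence. Assuming `learn_embedding` trains a skip-gram style word2vec model on the walks (the `vector_size` and `window` names suggest word2vec, but the source does not say which variant), the sketch below lists the (center, context) pairs that a window of w produces for one walk; `skipgram_pairs` is a hypothetical helper.

```python
def skipgram_pairs(walk, window):
    """(center, context) pairs within `window` positions of each node in one walk."""
    pairs = []
    for i, center in enumerate(walk):
        lo, hi = max(0, i - window), min(len(walk), i + window + 1)
        for j in range(lo, hi):
            if j != i:  # a node is not its own context
                pairs.append((center, walk[j]))
    return pairs

walk = ["JPM", "GS", "MS", "DAL", "UAL"]
print(skipgram_pairs(walk, window=1))
```

Word2vec implementations typically shrink the window at random for each position, so `window` acts as an upper bound; this sketch always uses the full window.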

## Examples of Qualitative Evaluation
* Stock Similarity: find the most similar stocks by cosine distance
* Analogical Inference: answer questions like "JPM is to GS as JNJ is to ?"
* Not Match Stock: pick the stock that does not belong, e.g. among JPM, MS, GS and GOOGL the odd one out is GOOGL
* Find Similar Stock for Given Stocks: pick the candidate closest to a target, e.g. among JNJ, MS, MOS and META the stock most similar to GOOGL is META

In [14]:
# stock similarity
se.find_similar_stocks(ticker='JPM')

|   | ticker | distance | sector | group | industry |
|---|---|---|---|---|---|
| 0 | MS | 0.022767 | Financials | Diversified Financials | Capital Markets |
| 1 | C | 0.024446 | Financials | Banks | Banks |
| 2 | WFC | 0.024905 | Financials | Banks | Banks |
| 3 | BAC | 0.026931 | Financials | Banks | Banks |
| 4 | GS | 0.02923 | Financials | Diversified Financials | Capital Markets |
| 5 | CFG | 0.072694 | Financials | Banks | Banks |
| 6 | FITB | 0.077486 | Financials | Banks | Banks |
| 7 | COF | 0.093276 | Financials | Diversified Financials | Consumer Finance |
| 8 | SIVB | 0.096844 | Financials | Banks | Banks |
| 9 | CMA | 0.097389 | Financials | Banks | Banks |
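
A minimal sketch of the similarity lookup using cosine distance (1 minus cosine similarity) on made-up 3-dimensional vectors. The `distance` column above is consistent with this definition, but that is an assumption, as are the helper names `cosine_distance` and `find_similar`.

```python
import math

def cosine_distance(u, v):
    """1 - cosine similarity; 0 for vectors pointing the same way."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / norm

def find_similar(embeddings, ticker, topn=10):
    """Other tickers sorted by ascending cosine distance to `ticker`."""
    target = embeddings[ticker]
    scored = [(t, cosine_distance(target, vec)) for t, vec in embeddings.items() if t != ticker]
    return sorted(scored, key=lambda item: item[1])[:topn]

# Toy embeddings (hypothetical values)
embeddings = {
    "JPM": [1.0, 0.9, 0.0],
    "GS": [0.9, 1.0, 0.1],
    "BAC": [1.0, 0.8, 0.1],
    "DAL": [0.0, 0.1, 1.0],
}
print(find_similar(embeddings, "JPM", topn=2))
```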


In [17]:
# analogical inference
se.analogical_inference(ticker='AAPL', ticker_1='JPM', ticker_2='GS')

JPM is to GS as AAPL is to AMZN


('AMZN', 0.9349383)

In [16]:
se.analogical_inference(ticker='JNJ', ticker_1='JPM', ticker_2='GS')

JPM is to GS as JNJ is to GILD


('GILD', 0.8816782)
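
Analogies are commonly solved with vector arithmetic: for "a is to b as c is to ?", rank the remaining stocks by cosine similarity to unit(b) - unit(a) + unit(c). This is the approach behind gensim's `most_similar(positive=[b, c], negative=[a])`; whether `analogical_inference` does exactly this, and whether its printed number is that cosine similarity, is an assumption. The toy vectors and helper names are hypothetical.

```python
import math

def unit(v):
    """Scale a vector to length 1."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def cosine(u, v):
    """Cosine similarity of two vectors."""
    return sum(a * b for a, b in zip(unit(u), unit(v)))

def analogy(embeddings, a, b, c):
    """Answer 'a is to b as c is to ?' with the nearest neighbour of unit(b) - unit(a) + unit(c)."""
    ua, ub, uc = unit(embeddings[a]), unit(embeddings[b]), unit(embeddings[c])
    target = [xb - xa + xc for xa, xb, xc in zip(ua, ub, uc)]
    # Exclude the query stocks themselves, as word-analogy solvers usually do
    candidates = [(t, cosine(target, v)) for t, v in embeddings.items() if t not in {a, b, c}]
    return max(candidates, key=lambda item: item[1])

# Toy embeddings (hypothetical values)
embeddings = {
    "JPM": [1.0, 0.0, 0.0],
    "GS": [1.0, 1.0, 0.0],
    "AAPL": [0.0, 0.0, 1.0],
    "AMZN": [0.0, 1.0, 1.0],
    "MSFT": [0.1, 0.0, 1.0],
}
print(analogy(embeddings, "JPM", "GS", "AAPL"))
```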

In [18]:
# not match stock
se.identify_not_match_stock(['JPM', 'MS', 'GS', 'GOOGL'])

Does not match from JPM, MS, GS, GOOGL: GOOGL


'GOOGL'

In [19]:
se.identify_not_match_stock(['JNJ', 'BMY', 'PFE', 'HD'])

Does not match from JNJ, BMY, PFE, HD: HD


'HD'

In [20]:
se.identify_not_match_stock(['UAL', 'AAL', 'DAL', 'TSLA'])

Does not match from UAL, AAL, DAL, TSLA: TSLA


'TSLA'
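
An odd-one-out check can be sketched as: normalize each vector, average them, and return the stock least similar to that average. This mirrors gensim's `doesnt_match`; treating `identify_not_match_stock` as equivalent is an assumption, and the vectors below are made up.

```python
import math

def unit(v):
    """Scale a vector to length 1."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def doesnt_match(embeddings, tickers):
    """Ticker whose unit vector is least aligned with the group's mean unit vector."""
    vecs = {t: unit(embeddings[t]) for t in tickers}
    mean = unit([sum(col) / len(vecs) for col in zip(*vecs.values())])
    return min(tickers, key=lambda t: sum(a * b for a, b in zip(vecs[t], mean)))

# Toy embeddings (hypothetical values): three banks and one tech stock
embeddings = {
    "JPM": [1.0, 0.9, 0.0],
    "MS": [0.9, 1.0, 0.1],
    "GS": [1.0, 1.0, 0.0],
    "GOOGL": [0.0, 0.2, 1.0],
}
print(doesnt_match(embeddings, ["JPM", "MS", "GS", "GOOGL"]))  # GOOGL
```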

In [21]:
# Find Similar Stock for Given Stocks
se.identify_similar_stock(ticker='GOOGL', tickers=['JNJ', 'MS', 'MOS', 'META'])

Most similar to GOOGL given JNJ, MS, MOS, META: META


'META'

In [22]:
se.identify_similar_stock(ticker='BLK', tickers=['TSLA', 'STT', 'JNJ', 'AAPL'])

Most similar to BLK given TSLA, STT, JNJ, AAPL: STT


'STT'

In [23]:
se.identify_similar_stock(ticker='WMT', tickers=['CVS', 'COST', 'JNJ', 'MSFT'])

Most similar to WMT given CVS, COST, JNJ, MSFT: COST


'COST'
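
Picking the candidate closest to a target reduces to an argmax of cosine similarity, as in gensim's `most_similar_to_given`. That `identify_similar_stock` works this way is an assumption, and the vectors below are toy data.

```python
import math

def cosine(u, v):
    """Cosine similarity of two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def most_similar_to_given(embeddings, ticker, candidates):
    """Candidate with the highest cosine similarity to `ticker`."""
    return max(candidates, key=lambda t: cosine(embeddings[ticker], embeddings[t]))

# Toy embeddings (hypothetical values)
embeddings = {
    "GOOGL": [0.1, 0.2, 1.0],
    "META": [0.2, 0.3, 0.9],
    "JNJ": [0.9, 0.1, 0.1],
    "MS": [0.8, 0.9, 0.1],
    "MOS": [0.5, -0.4, 0.2],
}
print(most_similar_to_given(embeddings, "GOOGL", ["JNJ", "MS", "MOS", "META"]))  # META
```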