In [1]:
import jax
import jax.numpy as jnp
import pickle 
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import rankdata
from tfm.utils._data import *
from tfm.utils._constants import *

# Restrict CUDA to physical GPU 1 (the device list printed below shows a single CUDA
# device). NOTE(review): this only takes effect if JAX has not yet initialised its GPU
# backend when this line runs — setting it before `import jax` would be safer.
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
# Enable 64-bit dtypes and select the GPU as JAX's platform.
jax.config.update("jax_enable_x64", True)
jax.config.update('jax_platform_name', 'gpu')
visible_devices = jax.devices()
print("JAX is using device:", visible_devices[0], visible_devices)

def rank_norm_char(X, axis=0):
    """
    Rank-normalise characteristics into cross-sectional long-short weights.

    Along ``axis`` every column is replaced by its (average) rank scaled into (0, 1),
    demeaned, and divided by the sum of its absolute values. A column with any
    cross-sectional variation therefore sums to 0 and has unit L1 norm.

    Parameters
    ----------
    X : array-like
        Characteristic values; ``axis`` indexes the cross-section being ranked.
    axis : int, default 0
        Axis to rank along. The default matches the original per-month call on a
        single (stocks, characteristics) slice.

    Returns
    -------
    np.ndarray
        Float array with the same shape as ``X``. Columns whose values are all tied
        get zero weights instead of NaN (their demeaned ranks are all 0, so the
        unguarded division would be 0 / 0).
    """
    X = np.asarray(X)
    rank = rankdata(X, axis=axis) / (X.shape[axis] + 1)
    rank = rank - rank.mean(axis=axis, keepdims=True)
    scale = np.abs(rank).sum(axis=axis, keepdims=True)
    # Guard the degenerate (fully tied) column: leave it at zero rather than 0/0 = NaN.
    return np.divide(rank, scale, out=np.zeros_like(rank), where=scale > 0)


mp = {}

# Input tensor file and output directory for each dataset (processed in this order).
DATASETS = {
    'scs': (
        '/home/james/projects/tsfc/code/code_11092024/data/tensor-data-SCS-ind-med-cap.npz',
        '/home/james/projects/tsfc/code/code_11092024/organized_data/organized_data/scs',
    ),
    'wrds': (
        '/home/james/projects/tsfc/code/code_11092024/data/tensor-data-miss-wrds-full.npz',
        '/home/james/projects/tsfc/code/code_11092024/organized_data/organized_data/wrds',
    ),
}

for dataset, (data_path, dir_out) in DATASETS.items():

    ########################################################################
    ###### LOADING DATA, PREPROCESSING
    ########################################################################

    D = load_data(data_path, industry=dataset == 'scs')
    # Missing entries are zero-filled: missing characteristics are ranked as if they were 0,
    # and missing daily returns count as 0% when compounded below.
    C, R = jnp.nan_to_num(D['C']).astype(dtype_compact), jnp.nan_to_num(D['R']).astype(dtype_compact)
    factor_names, idx_month = D['factor_names'], D['idx_month']
    # The printed shapes suggest C is (T, 1, N, K) and R is (T, days, N) — TODO confirm.
    C = jnp.squeeze(C, axis=1)

    if dataset == 'scs':
        # characteristic 32 (IPO) always leads to NaNs; also remove the constant factor
        print(f'Dropping factors {factor_names[32]} and {factor_names[33]}')
        C = C[:, :, :32]
        factor_names = factor_names[:32]
    elif dataset == 'wrds':
        print(f"Dropping {factor_names[-1]}")
        C = C[:, :, :-1]
        factor_names = factor_names[:-1]

    assert not jnp.isnan(C).any().item()

    # Rank characteristics cross-sectionally (axis 1) for every period in one call.
    rank_c = rank_norm_char(C, axis=1)
    C = jnp.asarray(rank_c)
    T = C.shape[0]

    # Compound daily simple returns within each period into one simple return,
    # prod(1 + r) - 1, evaluated in log space for accuracy.
    R = jnp.expm1(jnp.log1p(R).sum(axis=1))

    # Factor return in period t for characteristic k: sum over stocks n of C[t, n, k] * R[t, n].
    F = jnp.einsum('tnk,tn->tk', C, R)
    # Annualise assuming monthly periods (x12 for the mean, x sqrt(12) for the Sharpe ratio);
    # jnp.std uses the population standard deviation (ddof=0).
    means = jnp.mean(F, axis=0) * 12
    sharpes = (jnp.mean(F, axis=0) / jnp.std(F, axis=0)) * jnp.sqrt(12)
    df = pd.DataFrame(data={"Factor Means": means, "Sharpe Ratios": sharpes}, index=factor_names)
    df.index.name = "Characteristic"
    mp[dataset] = df

JAX is using device: cuda:0 [CudaDevice(id=0)]
(594, 1, 9201, 34) (594, 23, 9201)
(594, 1, 1811, 34) (594, 23, 1811)
True
Dropping factors ipo and const
(552, 1, 12659, 107) (552, 23, 12659)
(552, 1, 1955, 107) (552, 23, 1955)
True
Dropping const


In [2]:
# Show every row when a DataFrame is displayed (no truncation).
pd.options.display.max_rows = None

In [3]:
import jax
import jax.numpy as jnp
import pandas as pd
# Read the characteristic-anomaly portfolio tensor and its parameter dict.
dir_input = '/home/james/projects/tsfc/code/code_11092024/organized_data/organized_data/char_anom'
max_lag = 120
# Use the archive as a context manager so the .npz file handle is closed right away
# instead of whenever the NpzFile object happens to be garbage-collected.
with np.load(f'{dir_input}/mat_ptf_re_lag_{max_lag}.npz') as npz:
    X = npz['mat_ptf_re_rank']  # dim: (T, max_lag, num_ptf)
# The parameter file is a pickled object (a dict, given `.values()` below), not an
# .npy/.npz array file, so read it with pickle directly.
# NOTE(review): unpickling can execute arbitrary code — only load trusted files.
with open(f'{dir_input}/dict_param_lag_{max_lag}.pkl', 'rb') as fh:
    params = pickle.load(fh)
num_ptf = X.shape[-1]

# NOTE(review): positional unpacking depends on the pickled dict's insertion order, and
# it rebinds max_lag to the stored value. Unpack by key once the key names are confirmed.
bin_labels, _, _, max_lag, frac_longshort, all_dates, start_date_maxlag = params.values()
# Lag index 0 of the (T, max_lag, num_ptf) tensor — presumably the contemporaneous
# portfolio return; TODO confirm.
F = X[:, 0, :]
# Annualise assuming monthly observations (x12 for the mean, x sqrt(12) for the Sharpe ratio).
means = jnp.mean(F, axis=0) * 12
sharpes = (jnp.mean(F, axis=0) / jnp.std(F, axis=0)) * jnp.sqrt(12)
df = pd.DataFrame(data={"Factor Means": means, "Sharpe Ratios": sharpes}, index=bin_labels)
df.index.name = "Characteristic"
mp['char_anom'] = df

In [9]:
# One summary row per dataset: plain and absolute averages of the per-characteristic
# annualised factor means and Sharpe ratios stored in `mp`.
summary_df = pd.DataFrame([
    {
        'Dataset': name,
        'Average Factor Means': mp[name]['Factor Means'].mean(),
        'Average Sharpe Ratio': mp[name]['Sharpe Ratios'].mean(),
        'Average Absolute Factor Means': mp[name]['Factor Means'].abs().mean(),
        'Average Absolute Sharpe Ratios': mp[name]['Sharpe Ratios'].abs().mean(),
    }
    for name in ['scs', 'wrds', 'char_anom']
])

summary_df

Unnamed: 0,Dataset,Average Factor Means,Average Sharpe Ratio,Average Absolute Factor Means,Average Absolute Sharpe Ratios
0,scs,-0.003539,-0.101743,0.017006,0.414933
1,wrds,0.012417,0.42203,0.015567,0.514396
2,char_anom,0.004715,0.06812,0.017669,0.35724


In [1]:
import jax.numpy as jnp

import pickle

# Load the pickled out-of-sample results dict for the SCS dataset. The file is a plain
# pickle rather than an .npy/.npz array file, so read it with pickle explicitly.
# NOTE(review): unpickling can execute arbitrary code — only load trusted files.
with open('/home/james/projects/tsfc/code/code_11092024/results_oos/multiperiod/scs/tensor_fig_oos_ret_rankptf_ver3/dict_tensor_oos_1982.pkl', 'rb') as fh:
    dict_tensor_oos = pickle.load(fh)


In [3]:
# Select the 'TFM' entry of the OOS results, then its element at index/key 36.
# NOTE(review): the meaning of 36 is not visible here; the displayed shape of the result
# is (320, 36, 7), so it may be a horizon/window length — TODO confirm.
ret = dict_tensor_oos['TFM'][36]

In [9]:
# Display the dimensions of the selected result array.
jnp.shape(ret)

(320, 36, 7)