In [7]:
import jax
import jax.numpy as jnp
import os
jax.config.update("jax_enable_x64", True)
os.environ['CUDA_VISIBLE_DEVICES'] = "1" # Titan and FP64 mode
jax.config.update('jax_platform_name', 'gpu')
print("JAX is using device:", jax.devices()[0], jax.devices())
from functools import partial
main_compute_device='gpu'

@partial(jax.jit, backend=main_compute_device, static_argnums=(1, 2))
def get_multiperiod_returns(X: jnp.ndarray, horizon: int, window_size: int):
    """
    """
    @partial(jax.jit, backend=main_compute_device, static_argnums=(1,))
    def _process(t: int, horizon: int):
        # Original code
        X_multi_s_temp = jnp.zeros(X.shape[2])
        for l in range(horizon):
            X_multi_s_temp += X[t+l, l, :]
        return X_multi_s_temp

        # @partial(jax.jit, backend=main_compute_device)
        # def _cumulative(l: int):
        #     return X[t+l, l, :] # is this indexing correct?

        # return jnp.sum(jax.vmap(_cumulative)(jnp.arange(horizon)), axis=0) # dim: (num_ptf)
        
    X_multi = jax.vmap(_process, in_axes=(0, None))(jnp.arange(window_size - horizon + 1), horizon)
    return X_multi # dim: (window_size - horizon + 1, num_ptf)

JAX is using device: cuda:0 [CudaDevice(id=0)]


In [5]:
X = jnp.arange(900).reshape((30, 10, 3))
x = get_multiperiod_returns(X, 5, 20)

In [8]:
y = get_multiperiod_returns(X, 5, 20)

In [10]:
x, y

(Array([[ 330,  335,  340],
        [ 480,  485,  490],
        [ 630,  635,  640],
        [ 780,  785,  790],
        [ 930,  935,  940],
        [1080, 1085, 1090],
        [1230, 1235, 1240],
        [1380, 1385, 1390],
        [1530, 1535, 1540],
        [1680, 1685, 1690],
        [1830, 1835, 1840],
        [1980, 1985, 1990],
        [2130, 2135, 2140],
        [2280, 2285, 2290],
        [2430, 2435, 2440],
        [2580, 2585, 2590]], dtype=int64),
 Array([[ 330.,  335.,  340.],
        [ 480.,  485.,  490.],
        [ 630.,  635.,  640.],
        [ 780.,  785.,  790.],
        [ 930.,  935.,  940.],
        [1080., 1085., 1090.],
        [1230., 1235., 1240.],
        [1380., 1385., 1390.],
        [1530., 1535., 1540.],
        [1680., 1685., 1690.],
        [1830., 1835., 1840.],
        [1980., 1985., 1990.],
        [2130., 2135., 2140.],
        [2280., 2285., 2290.],
        [2430., 2435., 2440.],
        [2580., 2585., 2590.]], dtype=float64))

In [6]:
x.shape

(16, 3)

In [5]:
T, N = 6, 3
X = jnp.arange(18).reshape((T, N))
Xbar = jnp.mean(X, axis=0)[..., None]
code = X.T @ jnp.ones((T, T)) @ X / (T ** 2)
paper = Xbar @ Xbar.T
code, paper

(Array([[56.25, 63.75, 71.25],
        [63.75, 72.25, 80.75],
        [71.25, 80.75, 90.25]], dtype=float64),
 Array([[56.25, 63.75, 71.25],
        [63.75, 72.25, 80.75],
        [71.25, 80.75, 90.25]], dtype=float64))

In [27]:
C = jnp.arange(30).reshape((3, 5, 2)) # Tx1xNxJ

print(C)

print() 
jnp.repeat(C, 2, axis=1).reshape(-1, 2)[:2]

[[[ 0  1]
  [ 2  3]
  [ 4  5]
  [ 6  7]
  [ 8  9]]

 [[10 11]
  [12 13]
  [14 15]
  [16 17]
  [18 19]]

 [[20 21]
  [22 23]
  [24 25]
  [26 27]
  [28 29]]]



Array([[0, 1],
       [0, 1]], dtype=int64)

In [2]:
x = jnp.arange(12).reshape((3, 2, 2))
y = jnp.arange(4).reshape((2, 2))
x, y, x @ y, x[2, :, :] @ y

: 

In [2]:
x = jnp.arange(1000).reshape((10, 10, 10))

In [3]:
x[:, jnp.newaxis].shape

(10, 1, 10, 10)

In [16]:
x, y, z = jnp.arange(10), jnp.zeros(10), jnp.full(10, 1)
assert x.shape == y.shape == z.shape == (10,)

jnp.array([x, y, z]).shape

(3, 10)

In [4]:
amt = 2
def func(i: int):
    # return jax.lax.dynamic_slice(x, (i, 0, 0), (amt, x.shape[1], x.shape[2])) # this works
    # # return x[:, :i] -> this doesn't work
    # return x[:, i] # this works
    return x[i], x[:, i]

def func3(i: int): # this doesn't work
    return jax.lax.dynamic_slice(x, (0, 0, 0), (i, 2, 2))

In [10]:
jax.vmap(jax.jit(func))(jnp.arange(4))[1].shape

(4, 10, 10)

In [39]:
x[2: 2 + 2]

Array([[[200, 201, 202, 203, 204, 205, 206, 207, 208, 209],
        [210, 211, 212, 213, 214, 215, 216, 217, 218, 219],
        [220, 221, 222, 223, 224, 225, 226, 227, 228, 229],
        [230, 231, 232, 233, 234, 235, 236, 237, 238, 239],
        [240, 241, 242, 243, 244, 245, 246, 247, 248, 249],
        [250, 251, 252, 253, 254, 255, 256, 257, 258, 259],
        [260, 261, 262, 263, 264, 265, 266, 267, 268, 269],
        [270, 271, 272, 273, 274, 275, 276, 277, 278, 279],
        [280, 281, 282, 283, 284, 285, 286, 287, 288, 289],
        [290, 291, 292, 293, 294, 295, 296, 297, 298, 299]],

       [[300, 301, 302, 303, 304, 305, 306, 307, 308, 309],
        [310, 311, 312, 313, 314, 315, 316, 317, 318, 319],
        [320, 321, 322, 323, 324, 325, 326, 327, 328, 329],
        [330, 331, 332, 333, 334, 335, 336, 337, 338, 339],
        [340, 341, 342, 343, 344, 345, 346, 347, 348, 349],
        [350, 351, 352, 353, 354, 355, 356, 357, 358, 359],
        [360, 361, 362, 363, 364, 365,

In [None]:
jax.vmap(func3)(jnp.arange(4))

In [76]:
def func2(x, y, z):
    return x * y * z

In [87]:
# Create some example arrays
x = jnp.arange(2)
y = jnp.arange(3)
z = jnp.arange(4)  
result = jax.vmap(
                jax.vmap(
                    jax.vmap(func2, in_axes=(0, None, None)), 
                          in_axes=(None, 0, None)), 
                  in_axes=(None, None, 0))(x, y, z)
result.T.shape

(2, 3, 4)

In [90]:
result.T[1, 2, 3]

Array(6, dtype=int64)

# Figuring out where multiperiod tensor is going wrong

In [42]:
old_dir_out = '/home/james/projects/tsfc/code/code_11092024/results_oos/multiperiod/char_anom/fig_onefit_oos_ret_rankptf_OLD/'
new_dir_out = '/home/james/projects/tsfc/code/code_11092024/results_oos/multiperiod/char_anom/fig_onefit_oos_ret_rankptf_ver1/'

In [43]:
old = jnp.load(old_dir_out + 'dict_tensor_oos.pkl', allow_pickle=True)

In [44]:
old[60]['approx'].shape


(424, 36, 1)

In [45]:
new = jnp.load(new_dir_out + 'dict_tensor_oos.pkl', allow_pickle=True)

In [46]:
new.keys()

dict_keys([60, 120])

In [41]:
old[60]['approx'][10, :, 0]

array([ 0.26586869, -0.1375646 , -0.11688133, -0.05469675, -0.11060408,
       -0.00474746,  0.05297104,  0.0812726 ,  0.06946345,  0.05595748,
        0.05070685,  0.06058645,  0.06319361,  0.0677459 ,  0.07374066,
        0.0906491 ,  0.07144239,  0.0853263 ,  0.09467011,  0.0985097 ,
        0.09526684,  0.09153597,  0.08991615,  0.0825991 ,  0.07421113,
        0.073383  ,  0.0689493 ,  0.07282168,  0.07114158,  0.07422516,
        0.07274161,  0.07263801,  0.07260948,  0.06824501,  0.06482834,
        0.06555689])

In [40]:
new[60][10, :, 2]

Array([ 0.20209107, -0.22513146, -0.13195244, -0.09189276, -0.14681543,
       -0.02986694,  0.03005314,  0.06473527,  0.05789063,  0.02526836,
        0.02014191,  0.03507391,  0.04802269,  0.05709412,  0.07550301,
        0.09473983,  0.07983973,  0.08493292,  0.09220442,  0.09344366,
        0.09566973,  0.09897536,  0.10259195,  0.08929108,  0.08310274,
        0.08547789,  0.08391242,  0.09076013,  0.0883622 ,  0.09477864,
        0.09725928,  0.10259348,  0.10706987,  0.10440189,  0.10658692,
        0.11039967], dtype=float64)

In [8]:
import pandas as pd
import numpy as np
from orig.tl_parafac_fix_intercept import parafac_fix_intercept
from orig.utils_tensor import get_normalized_factors
from tfm.parafac_jax import parafac_enhanced, normalize_factors
max_lag = 60
dir_input = '/home/james/projects/tsfc/code/code_11092024/organized_data/organized_data/char_anom'

In [9]:
X = jnp.load(f'{dir_input}/mat_ptf_re_lag_{max_lag}.npz')['mat_ptf_re_rank']
params = jnp.load(f'{dir_input}/dict_param_lag_{max_lag}.pkl', allow_pickle=True)
bin_labels, all_dates = params['lst_char'], params['all_dates']
T = len(all_dates)
start_date_oos = pd.to_datetime('2000-01-01')
start_date = all_dates[0]
dates_fit = all_dates[all_dates < start_date_oos]
dates_oos = all_dates[all_dates >= start_date_oos]
T_fit = len(dates_fit)
T_oos = len(dates_oos)


X_log = jnp.log(1 + X)
assert not jnp.isnan(X_log).any()

In [10]:
idx_window = 2
K = 3
window_size = 60

In [11]:
X_fit_old = X_log[idx_window:idx_window+window_size]

X_fit_new = jax.lax.dynamic_slice(X_log, 
        start_indices=(idx_window, 0, 0), slice_sizes=(window_size, X_log.shape[1], X.shape[2]))

In [12]:
jnp.array_equal(X_fit_new, X_fit_old)

Array(True, dtype=bool)

In [14]:
decomp= parafac_fix_intercept(X_fit_old,
        rank=K, 
        verbose=False,
        random_state=np.random.RandomState(100),
        return_errors=False,
        n_iter_max=100
        )

dict_fit=dict(zip(['F','W','B'], decomp.factors))
dict_fit['S']=decomp.weights
dict_fit = get_normalized_factors(dict_fit, reorder=True)
F, W, B, S=[dict_fit[key] for key in ['F','W','B','S']]

In [29]:
F_next=np.full((36, K), np.nan)
for idx_h in range(36):
    X_next=X_log[idx_window+window_size+idx_h]
    X_next_flatten=X_next.reshape(1,-1)
    Z_fit=np.full((K, X_log.shape[2] * 60),np.nan)
    for i in range(K):
        Z_fit[i]=np.kron(W[:,i], B[:,i]) * S[i]
    mat_weight_flatten=Z_fit.T@np.linalg.inv(Z_fit@Z_fit.T) # dim: (num_char*args.max_lag, K+1)
    F_next[idx_h,:]=X_next_flatten@mat_weight_flatten

In [30]:
F_next.shape

(36, 3)

In [17]:
weights, factors = parafac_enhanced(
    tensor=X_fit_new,
    rank=K,
    random_state=100,
    n_iter_max=100,
)

# Extract and normalize factors
factors = dict(zip(['F','W','B'], factors))
factors['S'] = weights
factors = normalize_factors(factors, reorder=True)
FF, WW, BB, SS=[dict_fit[key] for key in ['F','W','B','S']]

In [18]:
assert jnp.array_equal(FF, F)
assert jnp.array_equal(WW, W)
assert jnp.array_equal(BB, B)
assert jnp.array_equal(SS, S)

In [26]:
@jax.jit
def compute_Z_row(i: int):
    """Compute Z_fit for all K components at once"""
    return jnp.kron(factors['W'][:, i], factors['B'][:, i]) * factors['S'][i]

@jax.jit
def get_F_next(idx_h: int):
    """
    Get F_next for the next args.max_horizon periods by regressing X_next on tensor loadings
    """
    X_next_new = X_log[idx_window + window_size + idx_h] # this is fine
    X_next_flatten_new = X_next_new.reshape(1, -1)
    Z_fit_new = jax.vmap(compute_Z_row)(jnp.arange(K))
    mat_weight_flatten_new = Z_fit_new.T @ jnp.linalg.inv(Z_fit_new @ Z_fit_new.T) # dim: (num_char*args.max_lag, K+1)
    F_next_new = X_next_flatten_new @ mat_weight_flatten_new
    return jnp.squeeze(F_next_new, axis=0)

In [27]:
F_next_new = jax.vmap(get_F_next)(jnp.arange(36)) # dim: (args.max_horizon, K)
F_next_new.shape

(36, 3)

In [33]:
F_next[0, 0], F_next_new[0, 0]

(np.float64(-0.0439232686598585), Array(-0.04392327, dtype=float64))

In [34]:
jnp.allclose(F_next, F_next_new)

Array(True, dtype=bool)

In [35]:
FW=np.full((K, window_size, 60), np.nan)
for idx_k in range(K):
    FW[idx_k]=(F[:,idx_k][:,np.newaxis]@W[:,idx_k][np.newaxis,:]).cumsum(axis=1)

In [36]:
FW.shape

(3, 60, 60)

In [37]:
@jax.jit
def get_multiperiod_return(idx_k: int):
    """Get approximate multiperiod return in the window"""
    return jnp.cumsum(factors['F'][:, idx_k][:, jnp.newaxis] @ factors['W'][:, idx_k][jnp.newaxis, :], axis=1)

FW_new = jax.vmap(get_multiperiod_return)(jnp.arange(K)) # dim: (K, window_size, args.max_lag)

In [39]:
jnp.allclose(FW, FW_new)

Array(True, dtype=bool)

In [40]:
@jax.jit
def get_cov_approx(idx_s: int):
    """Approximate covariance matrix at all lookback times up to max lag."""
    return jnp.cov(FW_new[:, :, idx_s], bias=True)

# 1. Approximate calculation
mu_FW_approx_new = jnp.mean(FW_new, axis=1).T # dim: (args.max_lag, K)
cov_FW_approx_new = jax.vmap(get_cov_approx)(jnp.arange(max_lag)) # dim: (args.max_lag, K, K)

In [41]:
dict_mv = {}
mu_FW=np.mean(FW, axis=1).T # dim: (args.max_lag, K)
# note that FW has no time-series correlation in approx calc, because cumsum is over dimension W, not over F
cov_FW=np.full((60, K, K),np.nan)
for idx_s in range(60):
    s = idx_s+1
    # if args.fit_rx and idx_s==0:
    #     continue
    cov_FW[idx_s]=np.cov(FW[:,:,idx_s],bias=True)
dict_mv['approx'] = {'mu_FW':mu_FW, 'cov_FW':cov_FW}

In [43]:
jnp.allclose(mu_FW, mu_FW_approx_new), jnp.allclose(cov_FW, cov_FW_approx_new)

(Array(True, dtype=bool), Array(True, dtype=bool))

In [46]:
FW_next = (F_next*W[:36, :]).cumsum(axis=0) # is this right ????

In [45]:
FW_next_new = (F_next_new * factors['W'][:36, :]).cumsum(axis=0) # dim: (args.max_horizon, K)

In [47]:
jnp.allclose(FW_next, FW_next_new)

Array(True, dtype=bool)

In [48]:
       
mv_weight=np.full((36, K), np.nan)
for idx_s in range(36):
    # if args.fit_rx and idx_s==0:
    #     continue
    #print(calc_type, idx_s, cov_FW[idx_s])
    mv_weight[idx_s]=np.linalg.inv(cov_FW[idx_s])@mu_FW[idx_s]
# dict_mat_ret_tensor_oos[calc_type][idx_window, :, idx_K] =\
x = (FW_next * mv_weight[:36]).sum(axis=1) # dim: (args.max_horizon, )
x

array([-0.08145847,  0.13723988,  0.05761139,  0.17902112,  0.15863615,
        0.12708652,  0.17338235,  0.21906271,  0.21421067,  0.18467738,
        0.16657185,  0.16181734,  0.13734626,  0.14162367,  0.14923172,
        0.15327623,  0.14486802,  0.14530172,  0.13610477,  0.13588948,
        0.12961858,  0.12477082,  0.12377667,  0.12666162,  0.11505657,
        0.12126238,  0.12371712,  0.12352596,  0.11828099,  0.11398455,
        0.11068705,  0.10913654,  0.10564515,  0.10543698,  0.10311003,
        0.10283237])

In [49]:
@jax.jit
def get_multihorizon_and_ptf_returns():
    # Get OOS multi-horizon return FW - is this right?
    FW_next_new = (F_next * factors['W'][:36, :]).cumsum(axis=0) # dim: (args.max_horizon, K)

    def get_mv_weights(idx_s: int):
        if K > 1:
            return jnp.linalg.inv(cov_FW_approx_new[idx_s]) @ mu_FW_approx_new[idx_s]
        else:
            return jnp.mean(mu_FW_approx_new) / jnp.var(cov_FW_approx_new)
    
    mv_weights = jax.vmap(get_mv_weights)(jnp.arange(36)) # dim: (args.max_horizon, K)
    return (FW_next_new * mv_weights[:36]).sum(axis=1) # dim: (args.max_horizon)

ptf_returns = get_multihorizon_and_ptf_returns()

In [50]:
jnp.allclose(ptf_returns, x)

Array(True, dtype=bool)