In [34]:
import pandas as pd
import numpy as np
import scipy.optimize as opt

from statsmodels.api import add_constant
from numpy.linalg import lstsq, eigh, eigvalsh
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from scipy.optimize import minimize
from matplotlib import pyplot as plt

from skglm import GeneralizedLinearEstimator
from skglm.penalties import SCAD

import jax, jax.numpy as jnp
from jaxopt import LBFGS

from numba import njit, jit
import optax

In [35]:
import warnings

# Suppress FutureWarning messages for every cell executed after this one.
warnings.simplefilter(action="ignore", category=FutureWarning)

In [36]:
# Notebook-wide problem sizes.
# G is the number of groups (columns of alpha); R is the number of factors (NOTE).
G, R = 3, 3
GF = np.ones(3, dtype=int)  # NOTE(review): one entry per group, all equal to 1 - meaning not shown here

In [37]:
# NOTE this objective value is always without individual effects
def objective_value_without_individual_effects(y, x, beta, alpha, kappa, N, R):
    """Penalised objective without individual effects.

    Performs the same computation as the eight-argument variant defined later
    in this notebook, except that ``T`` is read from ``y`` instead of being
    passed in.

    The fit term is the sum of all but the ``R`` largest eigenvalues of
    ``res @ res.T / N`` divided by ``T``, where ``res`` holds the residuals
    ``y - x @ beta`` laid out as (T, N).  The penalty is ``kappa`` times the
    mean, over units, of the product over groups of the Euclidean distance
    between a unit's coefficient column and each group's coefficient column.

    Parameters
    ----------
    y : ndarray, (N, T, 1) or (N, T)
    x : ndarray, (N, T, K)
    beta : ndarray, (K, N)
    alpha : ndarray, (K, G)
    kappa : float
        Penalty weight.
    N : int
        Divisor used for the residual second-moment matrix.
    R : int
        Number of largest eigenvalues to drop (number of factors).

    Returns
    -------
    float
        Fit term plus penalty.

    Raises
    ------
    ValueError
        If ``R`` is negative.
    """
    if R < 0:
        raise ValueError("R must be non-negative")

    # Accept both the (N, T, 1) layout used in this notebook and a plain (N, T) array.
    if y.ndim == 3:
        y = np.squeeze(y, axis=2)
    T = y.shape[1]

    res = (y - np.sum(x * beta.T[:, None, :], axis=2)).T  # (T, N)
    v_res = (res @ res.T) / N

    # eigvalsh returns eigenvalues in ascending order, so the R largest sit at the end.
    # An explicit stop index is used because ``[:-R]`` would be empty for R == 0.
    eigenvalues = eigvalsh(v_res)
    base = eigenvalues[: max(eigenvalues.size - R, 0)].sum() / T

    distances = np.linalg.norm(beta[:, :, None] - alpha[:, None, :], axis=0)  # (N, G)
    penalty = np.mean(np.prod(distances, axis=1)) * kappa
    return base + penalty

In [38]:
# NOTE for some reason this is slower than the non-jnp version
@jax.jit
def jnp_objective_value(y, x, beta, alpha, mu, kappa):
    """JIT-compiled objective: mean squared residual plus a weighted penalty.

    The residual is ``y - sum_k x[..., k] * beta[k, i] - mu``.  The penalty is
    ``kappa`` times the mean over units of the product, over the columns of
    ``alpha``, of the Euclidean distance between a column of ``beta`` and that
    column of ``alpha``.

    Assumes x is (N, T, K), beta is (K, N), alpha is (K, G), and y and mu
    broadcast against an (N, T) array - TODO confirm against callers.
    """
    fitted = jnp.sum(x * beta.T[:, None, :], axis=2)
    residual = y - fitted - mu
    base = (residual ** 2).mean()

    distances = jnp.linalg.norm(beta[:, :, None] - alpha[:, None, :], axis=0)
    penalty = kappa * jnp.mean(jnp.prod(distances, axis=1))
    return base + penalty

In [39]:
def unpack(theta):
    """Split a flat parameter vector into ``(beta, mu, alpha)``.

    Reads the notebook-level globals ``K``, ``N`` and ``G``.  Layout of
    ``theta``: the first ``K * N`` entries form ``beta`` (K, N), the next
    ``N`` entries form ``mu`` (N, 1), and the remainder forms ``alpha`` (K, G).
    """
    beta_end = K * N
    mu_end = beta_end + N

    beta = theta[:beta_end].reshape(K, N)
    mu = theta[beta_end:mu_end].reshape(N, 1)
    alpha = theta[mu_end:].reshape(K, G)
    return beta, mu, alpha

def pack(beta, mu, alpha):
    """Flatten ``beta``, ``mu`` and ``alpha`` (in that order) into one 1-D vector."""
    pieces = [part.flatten() for part in (beta, mu, alpha)]
    return np.concatenate(pieces, axis=0)

In [40]:
# @njit
def obj(theta, kappa=1):
    """Evaluate the objective at a packed parameter vector ``theta``.

    Uses the notebook-level ``y`` and ``x`` and the helper ``unpack``.

    NOTE(review): ``objective_value`` is not defined in the visible cells of
    this notebook; it must already exist in the kernel for this call to work.
    """
    params = unpack(theta)  # (beta, mu, alpha)
    # The objective takes alpha before mu, so reorder the unpacked pieces.
    return objective_value(y, x, params[0], params[2], params[1], kappa)

In [41]:
# Get Initial
# TODO think of better initialization technique
def _generate_initial_estimates(y, x, N, T, K, G):
    """Build starting values for the optimiser.

    Steps:
      1. ``beta_init`` (K, N): per-unit least-squares fit of ``y[i]`` on ``x[i]``.
      2. ``alpha_init`` (K, G): K-means cluster centres of the columns of ``beta_init``.
      3. Any centre with a coordinate closer than 1e-2 to the matching coordinate
         of some unit estimate is pushed away from zero by 0.1 per coordinate.
      4. ``mu_init``: mean of ``y`` along axis 1.

    Parameters
    ----------
    y : ndarray
        ``y[i]`` must be reshapeable to (T, 1).
    x : ndarray
        ``x[i]`` must be reshapeable to (T, K).
    N, T, K, G : int
        Number of units, periods, regressors and groups.

    Returns
    -------
    tuple
        ``(beta_init, alpha_init, mu_init)``.
    """
    # Allocate from K and N directly.  The previous ``np.zeros_like(beta)`` read a
    # notebook-global ``beta`` that does not exist on a fresh kernel and would
    # also have inherited its dtype (truncating the estimates for an int array).
    beta_init = np.zeros((K, N))

    for i in range(N):
        beta_init[:, i:i+1] = lstsq(x[i].reshape(T, K), y[i].reshape(T, 1))[0]
    alpha_init = KMeans(n_clusters=G).fit(beta_init.T).cluster_centers_.T

    for j in range(G):
        # Nudge a centre that (coordinate-wise) nearly coincides with a unit estimate.
        if (np.abs(beta_init.T - alpha_init[:, j]).min() < 1e-2):
            alpha_init[:, j] += 1e-1 * np.sign(alpha_init[:, j])

    mu_init = np.mean(y, axis=1)

    return beta_init, alpha_init, mu_init

In [42]:
import pickle

# NOTE: unpickling can execute arbitrary code - only load files from a trusted source.
# The context manager closes the file handle once loading has finished.
with open("dgp3-example-2.pkl", "rb") as example_file:
    example = pickle.load(example_file)

In [43]:
# Unpack the pickled example by position.
x_example = example[0]
# Give y a trailing singleton axis.  The leading two dimensions are taken from
# x_example (whose shape is later unpacked as N, T, K) rather than hard-coded,
# so other example files with different sizes load without editing this cell.
y_example = np.reshape(example[1], (*x_example.shape[:2], 1))
g_example = example[2]
alpha_example = example[3]
beta_example = example[4]
beta_example

array([[1, 1, 1],
       [2, 2, 2],
       [3, 3, 3]])

In [44]:
# Display the shape of y after the reshape above.
y_example.shape

(30, 10, 1)

In [45]:
# Working data for the estimation cells, plus the dimensions read from x.
x, y = x_example, y_example
N, T, K = x.shape
G = 3

In [46]:
# Cast both arrays to single precision.
y, x = (np.float32(arr) for arr in (y, x))

In [47]:
from joblib import Memory

# Per the joblib documentation, location=None means no caching is done and the
# Memory object is transparent; verbose=0 keeps it silent.
memory = Memory(location=None, verbose=0)

In [48]:
# NOTE this objective value is always without individual effects
def objective_value_without_individual_effects(y, x, beta, alpha, kappa, N, R, T):
    """Penalised objective without individual effects.

    The fit term is the sum of all but the ``R`` largest eigenvalues of
    ``res @ res.T / N`` divided by ``T``, where ``res`` holds the residuals
    ``y - x @ beta`` laid out as (T, N).  The penalty is ``kappa`` times the
    mean, over units, of the product over groups of the Euclidean distance
    between a unit's coefficient column and each group's coefficient column.

    Parameters
    ----------
    y : ndarray, (N, T, 1) or (N, T)
    x : ndarray, (N, T, K)
    beta : ndarray, (K, N)
    alpha : ndarray, (K, G)
    kappa : float
        Penalty weight.
    N : int
        Divisor used for the residual second-moment matrix.
    R : int
        Number of largest eigenvalues to drop (number of factors).
    T : int
        Divisor applied to the retained eigenvalue sum.

    Returns
    -------
    float
        Fit term plus penalty.

    Raises
    ------
    ValueError
        If ``R`` is negative.
    """
    if R < 0:
        raise ValueError("R must be non-negative")

    # Accept an already-squeezed (N, T) array as well as the (N, T, 1) layout,
    # so callers may do the squeeze once outside the optimisation loop.
    if y.ndim == 3:
        y = np.squeeze(y, axis=2)

    res = (y - np.sum(x * beta.T[:, None, :], axis=2)).T  # (T, N)
    v_res = (res @ res.T) / N

    # eigvalsh returns eigenvalues in ascending order, so the R largest sit at the end.
    # An explicit stop index is used because ``[:-R]`` is an empty slice for R == 0,
    # which silently turned the fit term into zero.
    eigenvalues = eigvalsh(v_res)
    base = eigenvalues[: max(eigenvalues.size - R, 0)].sum() / T

    distances = np.linalg.norm(beta[:, :, None] - alpha[:, None, :], axis=0)  # (N, G)
    penalty = np.mean(np.prod(distances, axis=1)) * kappa
    return base + penalty

In [49]:
def _generate_initial_estimates(y, x, N, T, K, G):
    """Produce starting values ``(beta_init, alpha_init, mu_init)``.

    ``beta_init`` (K, N) holds one least-squares fit per unit, ``alpha_init``
    (K, G) holds the K-means centres of those fits (nudged by 0.1 per
    coordinate, away from zero, when a centre has a coordinate within 1e-2 of
    the matching coordinate of some unit fit), and ``mu_init`` is the mean of
    ``y`` along axis 1.
    """
    beta_init = np.zeros((K, N))
    for unit in range(N):
        design = x[unit].reshape(T, K)
        target = y[unit].reshape(T, 1)
        beta_init[:, unit : unit + 1] = lstsq(design, target)[0]

    clustering = KMeans(n_clusters=G).fit(beta_init.T)
    alpha_init = clustering.cluster_centers_.T

    for group in range(G):
        # Smallest coordinate-wise gap between this centre and any unit estimate.
        closest_gap = np.abs(beta_init.T - alpha_init[:, group]).min()
        if closest_gap < 1e-2:
            alpha_init[:, group] += 1e-1 * np.sign(alpha_init[:, group])

    mu_init = np.mean(y, axis=1)
    return beta_init, alpha_init, mu_init

In [50]:
R = 3
# Penalty weight actually used by every objective evaluation below.  The helper
# previously advertised ``kappa=0.1`` but ignored it and hard-coded 100.
KAPPA = 100
beta, alpha, _ = _generate_initial_estimates(y, x, N, T, K, G)

# NOTE may have to change to include multiple starting points
# but estimation is valid

obj_value = np.inf
last_obj_value = np.inf
for i in range(50):
    # Alternate optimisers: BFGS on even sweeps, Nelder-Mead on odd sweeps.
    if i % 2 == 0:
        method, options = "BFGS", {"maxiter": 10}
    else:
        method, options = "Nelder-Mead", {"adaptive": True, "maxiter": 100}

    for j in range(G):
        # Block update: optimise all of beta together with column j of alpha,
        # keeping the other columns of alpha at their current values.
        alpha_fixed = alpha.copy()

        def unpack_local(theta):
            """Rebuild (beta, alpha) from theta, filling column j of the frozen alpha."""
            beta = theta[: K * N].reshape(K, N)
            alpha = alpha_fixed.copy()
            alpha[:, j : j + 1] = theta[K * N:].reshape(K, 1)
            return beta, alpha

        def obj_local(theta, kappa=KAPPA):
            """Objective as a function of the packed (beta, alpha[:, j]) vector."""
            beta, alpha = unpack_local(theta)
            return objective_value_without_individual_effects(y, x, beta, alpha, kappa, N, R, T)

        def pack_local(beta, alpha):
            """Flatten beta followed by column j of alpha."""
            return np.concatenate((beta.flatten(), alpha[:, j].flatten()), axis=0)

        minimizer = opt.minimize(
            obj_local, pack_local(beta, alpha), method=method, options=dict(options), tol=1e-6
        )
        beta, alpha = unpack_local(minimizer.x)
        obj_value = minimizer.fun
        print(f"{method} Iteration {i}, Group {j}, Objective Value: {obj_value:.6f}")

    # Convergence is only tested after an even (BFGS) sweep.
    if np.abs(obj_value - last_obj_value) < 1e-4 and i % 2 == 0:
        print("Convergence reached.")
        break
    last_obj_value = obj_value

# Recover factors and loadings from the residuals at the final beta.
res = (np.squeeze(y, axis=2) - np.sum(x * beta.T[:, None, :], axis=2)).T  # (T, N)
res_var = (res @ res.T) / N
# eigh returns eigenvalues in ascending order; tuple unpacking works on every
# NumPy version, whereas the ``.eigenvectors`` attribute needs a recent release.
eigenvalues, eigenvectors = eigh(res_var)
factors = eigenvectors[:, -R:][:, ::-1]  # R leading eigenvectors, descending order

# Loadings: project the residuals onto each factor, giving an (R, N) array.
lambdas = factors.T @ res

beta, alpha, lambdas, factors

BFGS Iteration 0, Group 0, Objective Value: 69.194739
BFGS Iteration 0, Group 1, Objective Value: 23.044047
BFGS Iteration 0, Group 2, Objective Value: 12.323498
Nelder-Mead Iteration 1, Group 0, Objective Value: 11.320218
Nelder-Mead Iteration 1, Group 1, Objective Value: 10.717608
Nelder-Mead Iteration 1, Group 2, Objective Value: 10.438239
BFGS Iteration 2, Group 0, Objective Value: 6.368859
BFGS Iteration 2, Group 1, Objective Value: 5.429720
BFGS Iteration 2, Group 2, Objective Value: 5.058691
Nelder-Mead Iteration 3, Group 0, Objective Value: 4.896532
Nelder-Mead Iteration 3, Group 1, Objective Value: 4.699788
Nelder-Mead Iteration 3, Group 2, Objective Value: 4.699788
BFGS Iteration 4, Group 0, Objective Value: 4.448026
BFGS Iteration 4, Group 1, Objective Value: 4.325433
BFGS Iteration 4, Group 2, Objective Value: 4.284573
Nelder-Mead Iteration 5, Group 0, Objective Value: 4.284573
Nelder-Mead Iteration 5, Group 1, Objective Value: 4.284573
Nelder-Mead Iteration 5, Group 2, Obj

(array([[1.29424459, 1.29424458, 2.89008032, 1.29424461, 2.08900036,
         2.08900032, 1.2942446 , 2.89008033, 1.29424459, 2.08900037,
         2.89008028, 1.2942446 , 2.10305237, 2.08900033, 1.2942446 ,
         2.4945258 , 2.89008034, 2.89008034, 2.08900037, 1.29424462,
         2.89008034, 1.2942446 , 2.08900035, 2.89008033, 1.2942446 ,
         1.29424462, 2.89008031, 2.08900035, 2.08900035, 1.2942446 ],
        [1.46285931, 1.46285929, 2.78542398, 1.46285932, 1.95959681,
         1.95959679, 1.46285931, 2.785424  , 1.46285932, 1.95959685,
         2.78542401, 1.46285933, 1.92387523, 1.95959681, 1.46285932,
         2.40859262, 2.78542399, 2.78542401, 1.95959685, 1.46285932,
         2.785424  , 1.46285931, 1.9595968 , 2.78542401, 1.46285932,
         1.46285932, 2.78542399, 1.95959681, 1.95959681, 1.46285932],
        [1.3361986 , 1.33619859, 2.92839577, 1.3361986 , 2.0535309 ,
         2.05353088, 1.3361986 , 2.92839571, 1.3361986 , 2.05353081,
         2.92839566, 1.33619861,