In [120]:
import pandas as pd
import numpy as np
import scipy.optimize as opt

from statsmodels.api import add_constant
from numpy.linalg import lstsq, eigh, eigvalsh
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from scipy.optimize import minimize
from matplotlib import pyplot as plt

from skglm import GeneralizedLinearEstimator
from skglm.penalties import SCAD

import jax, jax.numpy as jnp
from jaxopt import LBFGS

from numba import njit, jit
import optax

In [11]:
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)

In [None]:
# Model dimensions: G groups and R factors.
G, R = 3, 3  # NOTE R is the number of factors
# NOTE(review): GF looks like a per-group factor count — confirm its meaning.
GF = np.array([1, 1, 1])

In [25]:
# NOTE this objective value is always without individual effects
def objective_value_without_individual_effects(y, x, beta, alpha, kappa, N, R):
    """Placeholder for the penalized objective without individual effects.

    This 7-argument draft was never finished and always raises. The previous
    body displayed the residual matrix for debugging and kept an unreachable
    ``return base + penalty`` after the ``raise`` even though ``base`` was never
    assigned; both have been removed so the stub fails cleanly.

    Draft formulas kept for reference (taken from the original body):
        base    = ((y - np.sum(x * beta.T[:, None, :], axis=2)) ** 2).mean()
        penalty = np.mean(np.prod(np.linalg.norm(
                      beta[:, :, None] - alpha[:, None, :], axis=0), axis=1)) * kappa

    Parameters
    ----------
    y, x, beta, alpha, kappa, N, R
        Accepted for signature compatibility only; none of them is used.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError("This function is not implemented yet.")

In [14]:
# NOTE for some reason this is slower than the non-jnp version
@jax.jit
def jnp_objective_value(y, x, beta, alpha, mu, kappa):
    """JIT-compiled penalized least-squares objective with individual effects.

    The value is the mean squared residual of ``y`` after removing the
    unit-specific linear fit and ``mu``, plus ``kappa`` times the average (over
    units) of the product (over groups) of Euclidean distances between each
    unit's coefficient column and each column of ``alpha``.

    The indexing below implies ``x`` is 3-D with the regressor axis last,
    ``beta`` has one column per unit and ``alpha`` one column per group.
    ``mu`` must broadcast against the 2-D fitted values.
    """
    # Broadcast each unit's coefficient vector over that unit's rows of x.
    per_unit_coef = beta.T[:, None, :]
    fitted = jnp.sum(x * per_unit_coef, axis=2)
    residual = y - fitted - mu
    fit_term = jnp.square(residual).mean()

    # Distance of every unit's coefficients to every group centre.
    gaps = beta[:, :, None] - alpha[:, None, :]
    distances = jnp.linalg.norm(gaps, axis=0)
    penalty_term = kappa * jnp.mean(jnp.prod(distances, axis=1))

    return fit_term + penalty_term

In [15]:
def unpack(theta):
    """Split a flat parameter vector into ``(beta, mu, alpha)``.

    Layout of ``theta``: the first ``K * N`` entries are ``beta`` (reshaped to
    ``(K, N)``), the next ``N`` entries are ``mu`` (reshaped to ``(N, 1)``), and
    everything after that is ``alpha`` (reshaped to ``(K, G)``).

    ``K``, ``N`` and ``G`` are read from the notebook's global namespace.
    """
    beta_stop = K * N
    mu_stop = beta_stop + N

    beta = theta[:beta_stop].reshape(K, N)
    mu = theta[beta_stop:mu_stop].reshape(N, 1)
    alpha = theta[mu_stop:].reshape(K, G)
    return beta, mu, alpha

def pack(beta, mu, alpha):
    """Flatten ``beta``, ``mu`` and ``alpha`` (in that order) into one 1-D vector."""
    pieces = [block.flatten() for block in (beta, mu, alpha)]
    return np.concatenate(pieces, axis=0)

In [16]:
# @njit
def obj(theta, kappa=1):
    """Evaluate the objective at a flat parameter vector.

    ``theta`` is split with ``unpack`` and forwarded, together with the global
    ``y`` and ``x``, to ``objective_value``.

    NOTE(review): ``objective_value`` is not defined in the visible part of the
    notebook — confirm it exists before running this cell.
    """
    coefficients, intercepts, centers = unpack(theta)
    return objective_value(y, x, coefficients, centers, intercepts, kappa)

In [17]:
# Get Initial
# TODO think of better initialization technique
def _generate_initial_estimates(y, x, N, T, K, G):
    """Build starting values for the unit coefficients, group centres and means.

    Parameters
    ----------
    y : array indexed by unit; ``y[i]`` must be reshapeable to ``(T, 1)``.
    x : array indexed by unit; ``x[i]`` must be reshapeable to ``(T, K)``.
    N, T, K, G : int
        Number of units, observations per unit, regressors and groups.

    Returns
    -------
    beta_init : ndarray, shape (K, N)
        Unit-by-unit least-squares coefficients.
    alpha_init : ndarray, shape (K, G)
        K-means centres of the columns of ``beta_init``, nudged where needed.
    mu_init : ndarray
        Mean of ``y`` over axis 1.
    """
    # Allocate from the declared dimensions. The previous np.zeros_like(beta)
    # read a `beta` that is neither a parameter nor a local, so the function
    # only worked when a suitably shaped global happened to exist.
    beta_init = np.zeros((K, N))

    for i in range(N):
        beta_init[:, i:i+1] = lstsq(x[i].reshape(T, K), y[i].reshape(T, 1))[0]
    alpha_init = KMeans(n_clusters=G).fit(beta_init.T).cluster_centers_.T

    # Push a centre away (by 0.1 in the direction of its own sign) whenever any
    # single coordinate of any unit's estimate lies within 1e-2 of it.
    for j in range(G):
        if (np.abs(beta_init.T - alpha_init[:, j]).min() < 1e-2):
            alpha_init[:, j] += 1e-1 * np.sign(alpha_init[:, j])

    mu_init = np.mean(y, axis=1)

    return beta_init, alpha_init, mu_init

In [134]:
import pickle

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles that come from a trusted source.
# The context manager closes the file handle, which the bare open() call leaked.
with open("dgp3-example-7.pkl", "rb") as example_file:
    example = pickle.load(example_file)

In [140]:
# Pull the individual pieces out of the loaded example by position.
x_example, g_example = example[0], example[2]
alpha_example, beta_example = example[3], example[4]
# The outcome gets an explicit trailing singleton axis.
y_example = np.reshape(example[1], (30, 100, 1))
beta_example

array([[1, 1, 1],
       [2, 2, 2],
       [3, 3, 3]])

In [141]:
# Display the shape of the reshaped outcome array as a quick sanity check.
y_example.shape

(30, 100, 1)

In [142]:
# Working aliases used by the estimation cells below.
x, y = x_example, y_example
# x is 3-D; its axes are unpacked as units, periods and regressors.
N, T, K = x.shape
G = 3

In [143]:
# Cast both data arrays to single precision.
y = y.astype(np.float32)
x = x.astype(np.float32)

In [126]:
from joblib import Memory

# NOTE(review): per joblib's documentation, location=None should make the Memory
# object a transparent no-op (nothing is cached) — confirm this is intended.
# `memory` is not referenced by any cell visible in this part of the notebook.
memory = Memory(location=None, verbose=0)

In [144]:
# NOTE this objective value is always without individual effects
def objective_value_without_individual_effects(y, x, beta, alpha, kappa, N, R, T):
    """Penalized objective with ``R`` factors concentrated out of the residuals.

    The fit term is the sum of all but the ``R`` largest eigenvalues of
    ``res @ res.T / N`` divided by ``T``, where ``res`` is the transposed matrix
    of residuals ``y - x @ beta`` computed unit by unit. The penalty is
    ``kappa`` times the average (over units) of the product (over groups) of
    Euclidean distances between each column of ``beta`` and each column of
    ``alpha``.

    Parameters
    ----------
    y : ndarray
        Outcomes, either 3-D with a trailing singleton axis or already 2-D.
    x : ndarray
        3-D regressors with the regressor axis last.
    beta : ndarray
        Coefficients, one column per unit.
    alpha : ndarray
        Group centres, one column per group.
    kappa : float
        Penalty weight.
    N, R, T : int
        Divisor of the residual cross-product, number of factors removed, and
        divisor of the eigenvalue sum.

    Returns
    -------
    float
        Fit term plus penalty.
    """
    # FIXME preferably the squeeze should be done by the caller
    if np.ndim(y) == 3:
        y = np.squeeze(y, axis=2)
    res = (y - np.sum(x * beta.T[:, None, :], axis=2)).T
    v_res = (res @ res.T) / N

    # eigvalsh returns eigenvalues in ascending order, so dropping the last R
    # removes the R largest. Slice by an explicit count: `[:-R]` is an empty
    # slice when R == 0 and would silently zero the fit term.
    eigenvalues = eigvalsh(v_res)
    n_kept = max(eigenvalues.shape[0] - R, 0)
    base = eigenvalues[:n_kept].sum() / T

    penalty = np.mean(np.prod(np.linalg.norm(beta[:, :, None] - alpha[:, None, :], axis=0), axis=1)) * kappa
    return base + penalty

In [145]:
def _generate_initial_estimates(y, x, N, T, K, G):
    """Produce starting values ``(beta_init, alpha_init, mu_init)``.

    ``beta_init`` (K x N) holds a separate least-squares fit for every unit,
    ``alpha_init`` (K x G) holds the K-means centres of those fits, and
    ``mu_init`` is the mean of ``y`` over axis 1.
    """
    beta_init = np.zeros((K, N))
    for unit in range(N):
        regressors = x[unit].reshape(T, K)
        outcome = y[unit].reshape(T, 1)
        coef, *_ = lstsq(regressors, outcome)
        beta_init[:, unit : unit + 1] = coef

    clustering = KMeans(n_clusters=G).fit(beta_init.T)
    alpha_init = clustering.cluster_centers_.T

    # Nudge a centre along its own sign when any coordinate of any unit's
    # estimate sits within 1e-2 of it.
    for group in range(G):
        center = alpha_init[:, group]
        smallest_gap = np.abs(beta_init.T - center).min()
        if smallest_gap < 1e-2:
            alpha_init[:, group] += 1e-1 * np.sign(center)

    mu_init = np.mean(y, axis=1)
    return beta_init, alpha_init, mu_init

In [None]:
# Alternating estimation: for each sweep and each group j, jointly optimise all
# unit coefficients plus that one group's centre while the other centres stay
# fixed; afterwards recover factors and loadings from the final residuals.
R = 3
beta, alpha, _ = _generate_initial_estimates(y, x, N, T, K, G)

# NOTE may have to change to include multiple starting points
# but estimation is valid

obj_value = np.inf
last_obj_value = np.inf
for i in range(50):
    # Even sweeps take a few BFGS steps, odd sweeps run adaptive Nelder-Mead.
    if i % 2 == 0:
        method, options = "BFGS", {"maxiter": 10}
    else:
        method, options = "Nelder-Mead", {"adaptive": True, "maxiter": 100}

    for j in range(G):
        alpha_fixed = alpha.copy()

        def unpack_local(theta):
            # theta = [all of beta, column j of alpha]; other columns come from alpha_fixed
            beta = theta[: K * N].reshape(K, N)
            alpha = alpha_fixed.copy()
            alpha[:, j : j + 1] = theta[K * N:].reshape(K, 1)
            return beta, alpha

        def obj_local(theta, kappa=100):
            # kappa was previously declared (default 0.1) but ignored in favour of
            # a hard-coded 100; the default is now the value that was really used.
            beta, alpha = unpack_local(theta)
            return objective_value_without_individual_effects(y, x, beta, alpha, kappa, N, R, T)

        def pack_local(beta, alpha):
            return np.concatenate((beta.flatten(), alpha[:, j].flatten()), axis=0)

        minimizer = opt.minimize(
            obj_local, pack_local(beta, alpha), method=method, options=dict(options), tol=1e-6
        )
        beta, alpha = unpack_local(minimizer.x)
        obj_value = minimizer.fun
        print(f"{method} Iteration {i}, Group {j}, Objective Value: {obj_value:.6f}")
    # NOTE(review): the check below is disabled; the recorded output shows sweeps
    # that leave the objective unchanged, which would trip it — confirm before re-enabling.
    # if np.abs(obj_value - last_obj_value) < 1e-6:
    #     print("Convergence reached.")
    #     break
    last_obj_value = obj_value

# Squeeze only the trailing singleton axis so N == 1 or T == 1 cannot collapse y.
res = (np.squeeze(y, axis=2) - np.sum(x * beta.T[:, None, :], axis=2)).T
res_var = (res @ res.T) / N
# Tuple unpacking works whether or not eigh returns a named tuple.
_, eigenvectors = eigh(res_var)
# eigh orders eigenvalues ascending: keep the last R columns, then reverse
# them to have descending order.
factors = eigenvectors[:, eigenvectors.shape[1] - R:][:, ::-1]

# Row r holds the projection of the residuals on factor r.
lambdas = factors.T @ res

beta, alpha, lambdas, factors

BFGS Iteration 0, Group 0, Objective Value: 115.321220
BFGS Iteration 0, Group 1, Objective Value: 55.404387
BFGS Iteration 0, Group 2, Objective Value: 21.992048
Nelder-Mead Iteration 1, Group 0, Objective Value: 20.356378
Nelder-Mead Iteration 1, Group 1, Objective Value: 18.737714
Nelder-Mead Iteration 1, Group 2, Objective Value: 18.643584
BFGS Iteration 2, Group 0, Objective Value: 10.931689
BFGS Iteration 2, Group 1, Objective Value: 7.143474
BFGS Iteration 2, Group 2, Objective Value: 5.035192
Nelder-Mead Iteration 3, Group 0, Objective Value: 5.035192
Nelder-Mead Iteration 3, Group 1, Objective Value: 5.035192
Nelder-Mead Iteration 3, Group 2, Objective Value: 5.035192
BFGS Iteration 4, Group 0, Objective Value: 4.149728
BFGS Iteration 4, Group 1, Objective Value: 3.480070
BFGS Iteration 4, Group 2, Objective Value: 3.243459
Nelder-Mead Iteration 5, Group 0, Objective Value: 3.243459
Nelder-Mead Iteration 5, Group 1, Objective Value: 3.243459
Nelder-Mead Iteration 5, Group 2, O

(array([[3.00540659, 3.00540659, 1.9694055 , 3.00540659, 1.40315712,
         1.40315713, 3.00540659, 1.40315713, 1.96940549, 3.00540659,
         1.9694055 , 1.9694055 , 1.40315713, 3.00540659, 1.9694055 ,
         1.40315713, 1.40315713, 1.9694055 , 1.40315713, 3.00540659,
         1.9694055 , 1.40315713, 1.9694055 , 3.00540659, 1.62632222,
         3.00540659, 1.96940549, 1.9694055 , 3.00540659, 3.00540659],
        [2.92607291, 2.9260729 , 1.87986349, 2.92607291, 1.38168267,
         1.38168268, 2.92607291, 1.38168268, 1.87986348, 2.92607291,
         1.87986349, 1.87986349, 1.38168268, 2.92607291, 1.87986349,
         1.38168268, 1.38168268, 1.87986349, 1.38168268, 2.92607291,
         1.87986349, 1.38168268, 1.87986349, 2.92607291, 1.57285883,
         2.92607291, 1.87986348, 1.8798635 , 2.92607291, 2.92607291],
        [2.95591952, 2.95591952, 2.18545943, 2.95591952, 1.38111508,
         1.38111508, 2.95591952, 1.38111508, 2.18545942, 2.95591952,
         2.18545943, 2.18545943,