In [1]:
import numpy as np
from scipy.linalg import lapack
from scipy.special import expit
from scipy.stats import chi2
import pandas as pd
from scipy.special import expit
from firthlogist import FirthLogisticRegression

In [132]:
def _get_XW(X, preds, mask=None):
    # mask is 1-indexed because 0 == None
    rootW = np.sqrt(preds * (1 - preds))
    XW = rootW[..., np.newaxis] * X

    # is this equivalent??
    # https://github.com/georgheinze/logistf/blob/master/src/logistf.c#L150-L159
    if mask:
        XW[:, mask - 1] = 0
    return XW

def _get_aug_XW(X, preds, tau, h_diag):
    # mask is 1-indexed because 0 == None
    rootW = np.sqrt(preds * (1 - preds) * (1 + 2 * tau * h_diag))
    XW = rootW[..., np.newaxis] * X
    
    
    return XW

def _hat_diag(XW):
    # Get diagonal elements of the hat matrix
    # Q = np.linalg.qr(XW, mode="reduced")[0]
    qr, tau, _, _ = lapack.dgeqrf(XW)
    Q, _, _ = lapack.dorgqr(qr, tau)
    hat = np.einsum("ij,ij->i", Q, Q)
    return hat

def _loglikelihood(X, y, preds):
    """Return the *negative* penalized log-likelihood.

    The value is ``-(sum(y*log(p) + (1-y)*log(1-p)) + 0.5*log|X'WX|)`` where
    ``X'WX`` is built from ``_get_XW(X, preds)``. Callers negate the result
    to recover the penalized log-likelihood itself.

    Parameters
    ----------
    X : array
        Design matrix, one row per observation.
    y : array
        Binary responses, one per row of ``X``.
    preds : array
        Predicted probabilities, one per row of ``X``.

    Returns
    -------
    float
        Negative penalized log-likelihood (NumPy scalar).
    """
    XW = _get_XW(X, preds)
    fisher_info_mtx = XW.T @ XW
    # slogdet avoids the overflow/underflow that det() followed by log() can
    # hit when the determinant is very large or very small.
    sign, logabsdet = np.linalg.slogdet(fisher_info_mtx)
    # Keep the previous results for degenerate input: log(0) -> -inf (which
    # slogdet already reports) and log(negative) -> nan.
    log_det = logabsdet if sign >= 0 else np.nan
    penalty = 0.5 * log_det
    data_loglik = np.sum(y * np.log(preds) + (1 - y) * np.log(1 - preds))
    return -1 * (data_loglik + penalty)

def _get_preds(X, coef):
    return expit(X @ coef)

In [414]:
# Load both test datasets; only `diabetes` feeds X and y below. `sex2` is
# referenced only by the commented-out alternative.
sex2 = pd.read_csv('tests/sex2.csv')
diabetes = pd.read_csv('tests/diabetes.csv')
# diabetes.insert(0, 'intercept', np.repeat(1, diabetes.shape[0]))
# Alternative dataset (disabled):
# X = sex2.iloc[:, 1:].values
# y = sex2['case'].values
# Features are every column except the last; the response is 'Outcome'.
X = diabetes.iloc[:, :-1].values
y = diabetes['Outcome'].values
# Fit on X without an explicit intercept column; the estimator exposes
# intercept_ separately (it is read in the next cell).
firth = FirthLogisticRegression()
firth.fit(X, y)

FirthLogisticRegression()

In [415]:
# Prepend a column of ones so X lines up with [intercept, *coefficients].
# NOTE(review): this rebinds X, so re-running the cell prepends another
# column of ones; re-run the data-loading cell first.
X = np.hstack((np.ones((X.shape[0], 1)), X))
# Full coefficient vector of the fitted model, intercept first.
coef = np.array([firth.intercept_, *firth.coef_])

In [416]:
# Settings for the single hand-worked iteration below.
k = X.shape[1]  # number of columns of X (intercept included)
n = X.shape[0]  # number of observations
which = -1  # sign applied to sqrt(under_root) when forming lambda_
iSel = 1  # index of the coefficient being profiled
tol = firth.tol  # threshold for |loglik - LL0| in the convergence check
tau = 0.5  # multiplier on the hat-diagonal term in the weights
alpha = 0.05  # level used in chi2.ppf(1 - alpha, 1)
max_halfstep = 1000  # maximum number of passes of the step-halving loop
max_stepsize = 5  # presumably the cap on the largest coefficient change per iteration; confirm against the step-limiting code
max_iter = 25  # iteration limit for the full profile-likelihood loop

In [360]:
# Initialization
# Target level: the fitted model's loglik_ minus half the (1 - alpha)
# quantile of a chi-square with 1 degree of freedom.
LL0 = firth.loglik_ - (chi2.ppf(1 - alpha, 1) / 2)
preds = _get_preds(X, coef)
# _loglikelihood returns the negated value, so negate it back.
loglik = -_loglikelihood(X, y, preds)

# First Iteration

In [361]:
# calc XW^1/2
# xw2 is the transpose of the weighted design matrix, so xw2 @ xw2.T is the
# k x k Fisher information matrix.
xw2 = _get_XW(X, preds).T
fisher_info_matrix = xw2 @ xw2.T
# Not read by any later cell shown in this notebook.
log_fisher_det = np.log(np.linalg.det(fisher_info_matrix))
fisher_inv = np.linalg.inv(fisher_info_matrix)

In [362]:
# Calc hat diag and loglikelihood
tmpNxK = xw2.T @ fisher_inv
# This looks correct
# Forms the full n x n product and keeps only its diagonal.
h_diag = np.diag(tmpNxK @ xw2)
# Sanity check: one hat value per observation.
h_diag.shape

(239,)

In [363]:
# Get the augmented data using hat diag (lines 753-759 in C code)
# Same construction as the plain Fisher matrix, but the weights carry the
# extra factor (1 + 2 * tau * h_diag).
xw2_aug = _get_aug_XW(X, preds, tau, h_diag).T
fisher_aug = xw2_aug @ xw2_aug.T
# Log-determinant of the augmented matrix (despite the name, this is a log).
fisher_aug_det = np.log(np.linalg.det(fisher_aug))
fisher_aug_inv = np.linalg.inv(fisher_aug)

In [364]:
# calc the new weights
# Residual (y - preds) plus a hat-diagonal correction term scaled by 2 * tau.
w = (y - preds) + 2 * tau * h_diag * (0.5 - preds)
# Peek at the first five values only.
w[:5]

array([0.67250885, 0.07458449, 0.07458449, 0.07458449, 0.07458449])

In [365]:
# Calc Ustar and lambda values
# Ustar has one entry per column of X.
Ustar = X.T @ w
print(Ustar)
# NOTE: rebinds fisher_inv to the *negated* augmented inverse, replacing the
# plain inverse computed in an earlier cell.
fisher_inv = -fisher_aug_inv.copy()
tmpKx1 = Ustar.T @ fisher_inv
# Scalar quadratic form: -(Ustar' @ fisher_aug_inv @ Ustar).
tmp1x1 = tmpKx1.T @ Ustar
tmp1x1

[ 1.99780046e-06  1.94686835e-06 -1.47840235e-07  2.15588817e-06
  2.28431037e-06  2.07192716e-06  1.81926203e-06]


-7.551333525013887e-12

In [366]:
# Diagonal entry of the augmented inverse for the profiled coefficient; it
# is the denominator of under_root in the next cell.
fisher_aug_inv[iSel, iSel]

0.1721437593189589

In [367]:
# Compute lambda_ for the profiled coefficient and the resulting step.
under_root = -2 * ((LL0 - loglik) + 0.5 * tmp1x1) / fisher_aug_inv[iSel, iSel]
print(under_root)
# A negative radicand means no real root, so fall back to lambda_ = 0.
# `which` selects the sign of the root.
lambda_ = 0 if under_root < 0 else which * np.sqrt(under_root)
print(lambda_)
# Adjust a copy so that re-running this cell does not add lambda_ to Ustar
# a second time.
Ustar_adj = Ustar.copy()
Ustar_adj[iSel] += lambda_
step_size = fisher_aug_inv.T @ Ustar_adj
# Cap the largest absolute coefficient change at max_stepsize. The cap was
# previously computed from max_halfstep (the halving-loop count), which
# left max_stepsize unused and the cap effectively disabled.
if max_stepsize > 0:
    mx = np.max(np.abs(step_size)) / max_stepsize
    if mx > 1:
        step_size = step_size / mx
new_coef = coef + step_size

22.31541146713301
-4.723919079232096


In [368]:
# Step for this iteration, one entry per coefficient.
step_size

array([ 0.12738916, -0.81319301, -0.0096957 ,  0.10875095, -0.16123714,
       -0.00415691,  0.33766839])

In [358]:
# Step-halving: shrink the proposed step until the penalized log-likelihood
# at the trial point is closer to LL0 than before the step and stays above it.
# NOTE: step_size and new_coef are updated in place, so this cell should be
# run once per run of the step-computation cell above.
loglik_old = loglik.copy()
for halfs in range(1, max_halfstep + 1):
    new_preds = _get_preds(X, new_coef)
    loglik = -_loglikelihood(X, y, new_preds)
    # Refresh the Fisher quantities at the trial point. They are not needed
    # for the acceptance test below; they are kept so the globals describe
    # the accepted point.
    xw2 = _get_XW(X, new_preds).T
    fisher_info_matrix = xw2 @ xw2.T
    fisher_inv = np.linalg.inv(fisher_info_matrix)
    tmpNxK = xw2.T @ fisher_inv
    h_diag = np.diagonal(tmpNxK @ xw2)
    xw2_aug = _get_aug_XW(X, new_preds, tau, h_diag).T
    fisher_aug = xw2_aug @ xw2_aug.T
    # Stored under its own name instead of overwriting fisher_aug.
    fisher_aug_inv = np.linalg.inv(fisher_aug)
    if (abs(loglik - LL0) < abs(loglik_old - LL0)) and (loglik > LL0):
        print(halfs, 'broken out of loop')
        break
    # Halve the step and back the trial point off by the new half-step.
    step_size /= 2
    new_coef -= step_size
# Was `loglike`, a name no cell in this notebook defines.
if abs(loglik - LL0) <= tol:
    print(new_coef[iSel])

6 broken out of loop


In [359]:
# step_size after the halving loop, which divides it by 2 in place on every
# rejected trial.
step_size

array([ 0.00398091, -0.02541228, -0.00030299,  0.00339847, -0.00503866,
       -0.0001299 ,  0.01055208])

# Iteration To Max Iters

In [428]:
# Reload the data and refit so this section starts from a clean X and y.
# Only `diabetes` is used; `sex2` is referenced only by the commented-out
# alternative.
sex2 = pd.read_csv('tests/sex2.csv')
diabetes = pd.read_csv('tests/diabetes.csv')
# diabetes.insert(0, 'intercept', np.repeat(1, diabetes.shape[0]))
# Alternative dataset (disabled):
# X = sex2.iloc[:, 1:].values
# y = sex2['case'].values
# Features are every column except the last; the response is 'Outcome'.
X = diabetes.iloc[:, :-1].values
y = diabetes['Outcome'].values
firth = FirthLogisticRegression()
firth.fit(X, y)
# Prepend the intercept column only after fitting, so the model is fitted on
# the raw features.
X = np.hstack((np.ones((X.shape[0], 1)), X))

In [429]:

# Full coefficient vector of the fitted model, intercept first, matching the
# column order of X.
coef = np.array([firth.intercept_, *firth.coef_])

In [430]:
# Settings for the full profile-likelihood loop.
k = X.shape[1]  # number of columns of X (intercept included)
n = X.shape[0]  # number of observations
# which = -1 -> lower bounds, which = 1 -> upper bounds
which = -1
tol = firth.tol  # threshold for |loglik - LL0| in the convergence check
tau = 0.5  # multiplier on the hat-diagonal term in the weights
alpha = 0.05  # level used in chi2.ppf(1 - alpha, 1)
max_halfstep = 1000  # maximum number of passes of the step-halving loop
max_stepsize = 5  # presumably the cap on the largest coefficient change per iteration; confirm against the step-limiting code
max_iter = 25  # maximum number of outer iterations per coefficient

In [431]:
def _profile_likelihood_bound(X, y, coef_start, loglik_max, idx, which,
                              alpha=0.05, tau=0.5, tol=1e-4, max_iter=25,
                              max_stepsize=5, max_halfstep=1000):
    """Find one profile-likelihood confidence bound for a single coefficient.

    Starting from ``coef_start``, iterate until the penalized log-likelihood
    is within ``tol`` of the target level
    ``LL0 = loglik_max - chi2.ppf(1 - alpha, 1) / 2``.

    Parameters
    ----------
    X : array
        Design matrix including the intercept column.
    y : array
        Binary responses, one per row of ``X``.
    coef_start : array-like
        Starting coefficients (intercept first); not modified.
    loglik_max : float
        Penalized log-likelihood of the fitted model.
    idx : int
        Index of the coefficient whose bound is sought.
    which : int
        ``-1`` for the lower bound, ``1`` for the upper bound.
    alpha, tau, tol, max_iter, max_stepsize, max_halfstep
        Algorithm settings; see the configuration cell.

    Returns
    -------
    float
        The bound, or ``nan`` if ``max_iter`` iterations did not converge.
    """
    # Loop-invariant target level.
    LL0 = loglik_max - chi2.ppf(1 - alpha, 1) / 2
    coef = np.array(coef_start, dtype=float)

    for _ in range(max_iter):
        preds = _get_preds(X, coef)
        loglik = -_loglikelihood(X, y, preds)

        # Hat diagonal from the Fisher information, computed row-wise so the
        # n x n hat matrix is never formed.
        XW = _get_XW(X, preds)
        fisher_inv = np.linalg.inv(XW.T @ XW)
        h_diag = np.einsum("ij,ij->i", XW @ fisher_inv, XW)

        # Augmented Fisher information (lines 753-759 in the C code).
        XW_aug = _get_aug_XW(X, preds, tau, h_diag)
        fisher_aug_inv = np.linalg.inv(XW_aug.T @ XW_aug)

        # Modified score vector.
        Ustar = X.T @ ((y - preds) + 2 * tau * h_diag * (0.5 - preds))

        # lambda_ for the profiled coefficient; a negative radicand means no
        # real root, so fall back to 0.
        quad = Ustar @ fisher_aug_inv @ Ustar
        under_root = (
            -2 * ((LL0 - loglik) - 0.5 * quad) / fisher_aug_inv[idx, idx]
        )
        lambda_ = 0 if under_root < 0 else which * np.sqrt(under_root)
        Ustar[idx] += lambda_
        step_size = fisher_aug_inv.T @ Ustar

        # Cap the largest absolute coefficient change at max_stepsize. The
        # cap was previously computed from max_halfstep, which left
        # max_stepsize unused and the cap effectively disabled.
        if max_stepsize > 0:
            mx = np.max(np.abs(step_size)) / max_stepsize
            if mx > 1:
                step_size = step_size / mx
        coef = coef + step_size

        # Step-halving: shrink the step until the trial point is closer to
        # LL0 than the previous iterate and still above it.
        loglik_old = loglik
        for _ in range(max_halfstep):
            preds = _get_preds(X, coef)
            loglik = -_loglikelihood(X, y, preds)
            if abs(loglik - LL0) < abs(loglik_old - LL0) and loglik > LL0:
                break
            step_size = step_size / 2
            coef = coef - step_size

        if abs(loglik - LL0) <= tol:
            return coef[idx]

    # Not converged: return NaN so the result list stays aligned with the
    # coefficient indices instead of silently dropping this entry.
    return np.nan


coef_hat = np.array([firth.intercept_, *firth.coef_])
_bound_settings = dict(
    alpha=alpha, tau=tau, tol=tol, max_iter=max_iter,
    max_stepsize=max_stepsize, max_halfstep=max_halfstep,
)
# Both directions are computed explicitly, so the result names no longer
# depend on the value of the global `which` when the cell was last run.
lower_bds = [
    _profile_likelihood_bound(X, y, coef_hat, firth.loglik_, i, -1,
                              **_bound_settings)
    for i in range(k)
]
upper_bds = [
    _profile_likelihood_bound(X, y, coef_hat, firth.loglik_, i, 1,
                              **_bound_settings)
    for i in range(k)
]

In [424]:
# Upper profile-likelihood bounds, one per coefficient (intercept first).
# NOTE(review): presumably produced by running the bounds cell with
# which = 1; confirm this name is defined on a fresh kernel.
upper_bds

[-6.923864735572734,
 0.1845867771073193,
 0.041972293104437036,
 -0.0029846187016573204,
 0.014074318544182838,
 0.0005828123655084136,
 0.11804495417460169,
 1.5181403571733085,
 0.032864335783192394]

In [432]:
# Lower profile-likelihood bounds collected by the bounds cell above
# (intercept first).
lower_bds

[-9.706470179471,
 0.05981888758405426,
 0.027558848402368556,
 -0.023320855260184185,
 -0.012785547543677646,
 -0.0029204525872442745,
 0.05937009703858033,
 0.35283695211824373,
 -0.003424204408048102]