# Setup

In [28]:
import pandas as pd
import numpy as np
import jax
import pathlib
import xlogit
import sklearn

from jaxlogit.mixed_logit import MixedLogit, ConfigData

In [29]:
#  64bit precision
jax.config.update("jax_enable_x64", True)

# Get the full electricity dataset

In [30]:
df = pd.read_csv(pathlib.Path.cwd() / "electricity_long.csv")
varnames = ['pf', 'cl', 'loc', 'wk', 'tod', 'seas']
n_draws = 600

In [42]:
X = df[varnames]
y = df['choice']

ids = df['chid']
alts = df['alt']
panels = df['id']
randvars = {'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'}

model_jax = MixedLogit()
model_x = xlogit.MixedLogit()

config = ConfigData(
    panels=panels,
    n_draws=n_draws,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="L-BFGS-B",
)
init_coeff = None

In [50]:
y = df['choice']

# Make the model in jaxlogit

In [51]:
model_jax.fit(
    X=X,
    y=y,
    varnames=varnames,
    ids=ids,
    alts=alts,
    randvars=randvars,
    config=config
)
display(model_jax.summary())
init_coeff_jax = model_jax.coeff_

    Message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH
    Iterations: 80
    Function evaluations: 97
Estimation time= 39.0 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
pf                     -0.9972244     1.0000000    -0.9972244         0.319    
cl                     -0.2196763     1.0000000    -0.2196763         0.826    
loc                     2.2901926     1.0000000     2.2901926        0.0221 *  
wk                      1.6943196     1.0000000     1.6943196        0.0903 .  
tod                    -9.6753913     1.0000000    -9.6753913       6.4e-22 ***
seas                   -9.6962087     1.0000000    -9.6962087      5.24e-22 ***
sd.pf                  -1.3984445     1.0000000    -1.3984445         0.162    
sd.cl                  -0.6750223     1.0000000    -0.6750223       

None

# Make the model in xlogit

In [43]:
model_x.fit(
    X=X,
    y=y,
    varnames=varnames,
    ids=ids,
    alts=alts,
    randvars=randvars,
    panels=panels,
    n_draws=n_draws,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="L-BFGS-B",
)
display(model_x.summary())
init_coeff_x = model_x.coeff_

Optimization terminated successfully.
    Message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH
    Iterations: 54
    Function evaluations: 60
Estimation time= 15.9 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
pf                     -0.9971860     1.0000000    -0.9971860         0.319    
cl                     -0.2196661     1.0000000    -0.2196661         0.826    
loc                     2.2902861     1.0000000     2.2902861        0.0221 *  
wk                      1.6943008     1.0000000     1.6943008        0.0903 .  
tod                    -9.6751588     1.0000000    -9.6751588      6.42e-22 ***
seas                   -9.6960039     1.0000000    -9.6960039      5.25e-22 ***
sd.pf                   0.2207141     1.0000000     0.2207141         0.825    
sd.cl                   0.4115

None

# Predict from the model using jaxlogit

In [57]:
model = model_jax 
init_coeff = init_coeff_x # choose which params to use

In [58]:
config = ConfigData(
    panels=panels,
    n_draws=n_draws,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="L-BFGS-B",
)

In [59]:
config.init_coeff = init_coeff
print(config.init_coeff)

[-0.99718604 -0.21966612  2.29028615  1.69430084 -9.67515876 -9.69600385
  0.22071409  0.41156706  1.7840999   1.22968056  2.27591004  1.48629077]


In [60]:
prob_xj = model.predict(X, varnames, alts, ids, randvars, config)

# Predict from the model using xlogit

In [54]:
model = model_x 
init_coeff = init_coeff_jax # choose which params to use

In [55]:
_, prob_jx = model.predict(X, varnames, alts, ids, isvars=None, panels=panels, n_draws=n_draws, return_proba=True)
# X, varnames, alts, ids, panels=panels, n_draws=n_draws, return_proba=True

# Test the results

In [None]:
y = np.reshape(y, (prob.shape[0], -1))

total_counted = 0
correct = 0
for i in range(prob.shape[0]):
    y_index = np.argmax(y[i])
    if prob[i][y_index] == np.max(prob[i]):
        correct += 1
    total_counted += 1
print(f"percentage correct = {correct/total_counted}")

Compare the probabilities and results

Format:
[prob prob ... prob] : [which chosen]

In [None]:
for i in range(prob.shape[0]):
    print(f"{prob[i]} : {y[i]}")

In [37]:
y = np.reshape(y, (prob.shape[0], -1))
print(sklearn.metrics.brier_score_loss(y, prob))

0.6275143823422479


In [47]:
y = np.reshape(y, (prob_xx.shape[0], -1))
print(sklearn.metrics.brier_score_loss(y, prob_xx))

0.6275143437134769


In [56]:
y = np.reshape(y, (prob_jx.shape[0], -1))
print(sklearn.metrics.brier_score_loss(y, prob_jx))

0.6275143437134769


In [None]:
y = np.reshape(y, (prob_xj.shape[0], -1))
print(sklearn.metrics.brier_score_loss(y, prob_xj))

# Don't forget to clear output when done