# Summary of time taken and brier scores for jaxlogit, xlogit, and biogeme
Where the estimation is using draws = 600 (suboptimal but highest without running out of memory in biogeme), and training and test data is separated.

| | jaxlogit-scipy | jaxlogit-jax | xlogit | biogeme |
|---|---|---|---|---|
|Making Model | 33.1s | 22.2s | 18.5s | 4:30 |
|Estimating | 1.6s | 0.2s | 0.0s | 14.3s |
|Brier Score | 0.624247 | 0.624247 | 0.624570 | 0.624163 |

# Setup

In [17]:
import os
os.chdir("/home/evelyn/projects_shared/jaxlogit")

import pandas as pd
import numpy as np
import jax
import pathlib
import xlogit
import sklearn

from jaxlogit.mixed_logit import MixedLogit, ConfigData
from jaxlogit.utils import wide_to_long

import biogeme.biogeme_logging as blog
import biogeme.biogeme as bio
from biogeme import models
from biogeme.expressions import Beta, Draws, log, MonteCarlo, PanelLikelihoodTrajectory
import biogeme.database as db
from biogeme.expressions import Variable

logger = blog.get_screen_logger()
logger.setLevel(blog.INFO)

#  64bit precision
jax.config.update("jax_enable_x64", True)

os.chdir("/home/evelyn/projects_shared/jaxlogit/examples")

# Get the full electricity dataset

Use for jaxlogit and xlogit. Adjustusting n_draws can improve accuracy, but Biogeme cannot handle 700 or more draws with this data set.

In [18]:
df = pd.read_csv(pathlib.Path.cwd() / "electricity_long.csv")
varnames = ['pf', 'cl', 'loc', 'wk', 'tod', 'seas']
n_draws = 600

Reshape the data so it can be passed to test_train_split in a wide format. Additionally, xlogit and jaxlogit require long format while biogeme requires a wide format.

In [19]:
df_long = pd.read_csv(pathlib.Path.cwd() / "electricity_long.csv")
choice_df = df_long.loc[df_long['choice'] == 1, ['id', 'chid', 'alt']]
choice_df = choice_df.rename(columns={'alt': 'choice'})
df_wide = df_long.pivot(index=['id', 'chid'], columns='alt', values=varnames)
df_wide.columns = [f'{var}_{alt}' for var, alt in df_wide.columns]
df_wide = df_wide.reset_index()
df = df_wide.merge(
    choice_df,
    on=['id', 'chid'],
    how='inner',
    validate='one_to_one'
)

df_wide_train, df_wide_test = sklearn.model_selection.train_test_split(df, train_size=0.8)
df_train = wide_to_long(df_wide_train, "chid", [1,2,3,4], "alt", varying=varnames, panels=True)
df_train = df_train.sort_values(['chid', 'alt'])
df_test = wide_to_long(df_wide_test, "chid", [1,2,3,4], "alt", varying=varnames, panels=True)
df_test = df_test.sort_values(['chid', 'alt'])

df_wide_train = df_wide_train.sort_values('chid')
database_train = db.Database('electricity', df_wide_train)
database_train.panel('id')
database_test = db.Database('electricity', df_wide_test)

jaxlogit and xlogit setup:

In [20]:
X_train = df_train[varnames]
y_train = df_train['choice']

ids_train = df_train['chid']
alts_train = df_train['alt']
panels_train = df_train['id']

X_test = df_test[varnames]
y_test = df_test['choice']

ids_test = df_test['chid']
alts_test = df_test['alt']
panels_test = df_test['id']

In [21]:
randvars = {'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'}

model_jax = MixedLogit()
model_x = xlogit.MixedLogit()

init_coeff = None

Biogeme setup:

In [22]:
X = {
    name: {
        j: Variable(f"{name}_{j}")
        for j in [1,2,3,4]
    }
    for name in varnames
}

alt_1 = Beta('alt_1', 0, None, None, 0)
alt_2 = Beta('alt_2', 0, None, None, 0)
alt_3 = Beta('alt_3', 0, None, None, 0)
alt_4 = Beta('alt_4', 0, None, None, 1)

pf_mean = Beta('pf_mean', 0, None, None, 0)
pf_sd = Beta('pf_sd', 1, None, None, 0)
cl_mean = Beta('cl_mean', 0, None, None, 0)
cl_sd = Beta('cl_sd', 1, None, None, 0)
loc_mean = Beta('loc_mean', 0, None, None, 0)
loc_sd = Beta('loc_sd', 1, None, None, 0)
wk_mean = Beta('wk_mean', 0, None, None, 0)
wk_sd = Beta('wk_sd', 1, None, None, 0)
tod_mean = Beta('tod_mean', 0, None, None, 0)
tod_sd = Beta('tod_sd', 1, None, None, 0)
seas_mean = Beta('seas_mean', 0, None, None, 0)
seas_sd = Beta('seas_sd', 1, None, None, 0)

pf_rnd = pf_mean + pf_sd * Draws('pf_rnd', 'NORMAL')
cl_rnd = cl_mean + cl_sd * Draws('cl_rnd', 'NORMAL')
loc_rnd = loc_mean + loc_sd * Draws('loc_rnd', 'NORMAL')
wk_rnd = wk_mean + wk_sd * Draws('wk_rnd', 'NORMAL')
tod_rnd = tod_mean + tod_sd * Draws('tod_rnd', 'NORMAL')
seas_rnd = seas_mean + seas_sd * Draws('seas_rnd', 'NORMAL')

choice = Variable('choice')

V = {
    j: pf_rnd * X['pf'][j] + cl_rnd * X['cl'][j] + loc_rnd * X['loc'][j] + wk_rnd * X['wk'][j] + tod_rnd * X['tod'][j] + seas_rnd * X['seas'][j]
    for j in [1,2,3,4]
}

# Make the models
Jaxlogit:

In [23]:
model_jax_scipy = MixedLogit()
config = ConfigData(
    panels=panels_train,
    n_draws=n_draws,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="L-BFGS-B-scipy",
)
model_jax_scipy.fit(
    X=X_train,
    y=y_train,
    varnames=varnames,
    ids=ids_train,
    alts=alts_train,
    randvars=randvars,
    config=config
)
display(model_jax_scipy.summary())
init_coeff_scipy = model_jax_scipy.coeff_

    Message: unknown
    Iterations: 78
    Function evaluations: 94
Estimation time= 33.1 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
pf                     -0.9760174     1.0000000    -0.9760174         0.329    
cl                     -0.2323658     1.0000000    -0.2323658         0.816    
loc                     2.3265031     1.0000000     2.3265031          0.02 *  
wk                      1.6193361     1.0000000     1.6193361         0.105    
tod                    -9.4457607     1.0000000    -9.4457607       6.3e-21 ***
seas                   -9.4985088     1.0000000    -9.4985088      3.85e-21 ***
sd.pf                  -1.4450517     1.0000000    -1.4450517         0.149    
sd.cl                  -0.6657065     1.0000000    -0.6657065         0.506    
sd.loc                  1.5413155

None

In [24]:
model_jax = MixedLogit()
config = ConfigData(
    panels=panels_train,
    n_draws=n_draws,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="L-BFGS-jax",
)
model_jax.fit(
    X=X_train,
    y=y_train,
    varnames=varnames,
    ids=ids_train,
    alts=alts_train,
    randvars=randvars,
    config=config
)
display(model_jax.summary())
init_coeff_jax = model_jax.coeff_

**** The optimization did not converge after 89 iterations. ****
Convergence not reached. The estimates may not be reliable.


    Message: saddle point reached
    Iterations: 89
    Function evaluations: 134
Estimation time= 22.2 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
pf                     -0.9760555     1.0000000    -0.9760555         0.329    
cl                     -0.2323721     1.0000000    -0.2323721         0.816    
loc                     2.3266198     1.0000000     2.3266198          0.02 *  
wk                      1.6194020     1.0000000     1.6194020         0.105    
tod                    -9.4461198     1.0000000    -9.4461198      6.28e-21 ***
seas                   -9.4988576     1.0000000    -9.4988576      3.84e-21 ***
sd.pf                  -1.4450489     1.0000000    -1.4450489         0.149    
sd.cl                  -0.6656618     1.0000000    -0.6656618         0.506    
sd.loc             

None

xlogit:

In [26]:
model_x.fit(
    X=X_train,
    y=y_train,
    varnames=varnames,
    ids=ids_train,
    alts=alts_train,
    randvars=randvars,
    panels=panels_train,
    n_draws=n_draws,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="L-BFGS-B",
)
display(model_x.summary())
init_coeff_x = model_x.coeff_

Optimization terminated successfully.
    Message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH
    Iterations: 69
    Function evaluations: 80
Estimation time= 18.6 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
pf                     -0.9763775     1.0000000    -0.9763775         0.329    
cl                     -0.2421489     1.0000000    -0.2421489         0.809    
loc                     2.3592633     1.0000000     2.3592633        0.0184 *  
wk                      1.6141594     1.0000000     1.6141594         0.107    
tod                    -9.3777590     1.0000000    -9.3777590      1.18e-20 ***
seas                   -9.5553550     1.0000000    -9.5553550      2.26e-21 ***
sd.pf                  -0.2179004     1.0000000    -0.2179004         0.828    
sd.cl                   0.4063

None

Biogeme:

In [27]:
prob = models.logit(V, None, choice)
logprob = log(MonteCarlo(PanelLikelihoodTrajectory(prob)))

the_biogeme = bio.BIOGEME(
    database_train, logprob, number_of_draws=n_draws, seed=999, generate_yaml=False, generate_html=False
)
the_biogeme.model_name = 'model_b'
results = the_biogeme.estimate()
print(results)

The number of draws (600) is low. The results may not be meaningful. 
The number of draws (600) is low. The results may not be meaningful. 
The number of draws (600) is low. The results may not be meaningful. 
The number of draws (600) is low. The results may not be meaningful. 


Results for model model_b
Nbr of parameters:		12
Sample size:			361
Observations:			3446
Excluded data:			0
Final log likelihood:		-3205.346
Akaike Information Criterion:	6434.692
Bayesian Information Criterion:	6481.358



# Compare parameters:

In [None]:
print("{:>13} {:>13} {:>13} {:>13} {:>13}".format("Estimate", "Jaxlogit-scipy", "Jaxlogit-jax", "xlogit", "biogeme"))
print("-" * 58)
fmt = "{:13} {:13.7f} {:13.7f} {:13.7f} {:13.7f}"
biogeme_values = results.getBetaValues()
coeff_names = {'pf': 'pf_mean', 'sd.pf': 'pf_sd', 'cl': 'cl_mean', 'sd.cl': 'cl_sd', 'loc': 'loc_mean', 'sd.loc': 'loc_sd', 'wk': 'wk_mean', 'sd.wk': 'wk_sd', 'tod': 'tod_mean', 'sd.tod': 'tod_sd', 'seas': 'seas_mean', 'sd.seas': 'seas_sd'}
for i in range(len(model_jax.coeff_)):
    name = model_jax.coeff_names[i]
    print(fmt.format(name[:13], 
                     model_jax_scipy.coeff_[i], 
                     model_jax.coeff_[i], 
                     model_x.coeff_[i],
                     biogeme_values[coeff_names[name]]))
print("-" * 58)

     Estimate Jaxlogit-scipy  Jaxlogit-jax        xlogit       biogeme
----------------------------------------------------------


AttributeError: 'RawEstimationResults' object has no attribute 'data'

# Predict
jaxlogit:

In [30]:
model = model_jax_scipy 
config = ConfigData(
    panels=panels_test,
    n_draws=n_draws,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="L-BFGS-B-scipy",
)
config.init_coeff = init_coeff_scipy
prob_j_scipy = model.predict(X_test, varnames, alts_test, ids_test, randvars, config)

In [31]:
model = model_jax
config = ConfigData(
    panels=panels_test,
    n_draws=n_draws,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="L-BFGS-jax",
)
config.init_coeff = init_coeff_jax
prob_j_jax = model.predict(X_test, varnames, alts_test, ids_test, randvars, config)

xlogit:

In [32]:
_, prob_xx = model_x.predict(X_test, varnames, alts_test, ids_test, isvars=None, panels=panels_test, n_draws=n_draws, return_proba=True)

Biogeme:

In [33]:
P = {
    j: MonteCarlo(models.logit(V, None, j))
    for j in [1, 2, 3, 4]
}

simulate = {
    f'Prob_alt{j}': P[j]
    for j in [1, 2, 3, 4]
}

biogeme_sim = bio.BIOGEME(database_test, simulate)
biogeme_sim.model_name = 'per_choice_probs'

probs = biogeme_sim.simulate(results.get_beta_values())

Compute the brier score:

In [34]:
print("{:>9} {:>9} {:>9} {:>9}".format("Jaxlogit-scipy", "Jaxlogit-jax", "xlogit", "Biogeme"))
print("-" * 31)
fmt = "{:9f} {:9f} {:9f} {:9f}"
print(fmt.format(sklearn.metrics.brier_score_loss(np.reshape(y_test, (prob_j_scipy.shape[0], -1)), prob_j_scipy),
                 sklearn.metrics.brier_score_loss(np.reshape(y_test, (prob_j_jax.shape[0], -1)), prob_j_jax),
                 sklearn.metrics.brier_score_loss(np.reshape(y_test, (prob_xx.shape[0], -1)), prob_xx),
                 sklearn.metrics.brier_score_loss(df_wide_test['choice'], probs)))
print("-" * 31)

Jaxlogit-scipy Jaxlogit-jax    xlogit   Biogeme
-------------------------------
 0.624247  0.624248  0.624570  0.624163
-------------------------------
