# Summary of time taken and brier scores for jaxlogit, xlogit, and biogeme
Where the estimation is using draws = 600 (suboptimal but highest without running out of memory in biogeme), and training and test data is separated.

| | jaxlogit-scipy | jaxlogit-jax | xlogit | biogeme |
|---|---|---|---|---|
|Making Model | 33.1s | 22.2s | 18.5s | 4:30 |
|Estimating | 1.6s | 0.2s | 0.0s | 14.3s |
|Brier Score | 0.624247 | 0.624247 | 0.624570 | 0.624163 |

# Setup

In [32]:
import pandas as pd
import numpy as np
import jax
import pathlib
import xlogit
import sklearn

import os
os.chdir("/home/evelyn/projects_shared/jaxlogit")

from jaxlogit.mixed_logit import MixedLogit, ConfigData
from jaxlogit.utils import wide_to_long

os.chdir("/home/evelyn/projects_shared/jaxlogit/examples")

import biogeme.biogeme_logging as blog
import biogeme.biogeme as bio
from biogeme import models
from biogeme.expressions import Beta, Draws, log, MonteCarlo, PanelLikelihoodTrajectory
import biogeme.database as db
from biogeme.expressions import Variable

logger = blog.get_screen_logger()
logger.setLevel(blog.INFO)

#  64bit precision
jax.config.update("jax_enable_x64", True)

In [33]:
df_wide = pd.read_table("http://transp-or.epfl.ch/data/swissmetro.dat", sep='\t')

# Keep only observations for commute and business purposes that contain known choices
df_wide = df_wide[(df_wide['PURPOSE'].isin([1, 3]) & (df_wide['CHOICE'] != 0))]

df_wide['custom_id'] = np.arange(len(df_wide))  # Add unique identifier
df_wide['CHOICE'] = df_wide['CHOICE'].map({1: 'TRAIN', 2:'SM', 3: 'CAR'})
df_wide
df_wide_test, df_wide_train = sklearn.model_selection.train_test_split(df_wide, train_size=0.8)

Reshape the data so it can be passed to test_train_split in a wide format. Additionally, xlogit and jaxlogit require long format while biogeme requires a wide format.

In [34]:
from jaxlogit.utils import wide_to_long

df_train = wide_to_long(df_wide_train, id_col='custom_id', alt_name='alt', sep='_',
                  alt_list=['TRAIN', 'SM', 'CAR'], empty_val=0,
                  varying=['TT', 'CO', 'HE', 'AV', 'SEATS'], alt_is_prefix=True)
df_test = wide_to_long(df_wide_test, id_col='custom_id', alt_name='alt', sep='_',
                  alt_list=['TRAIN', 'SM', 'CAR'], empty_val=0,
                  varying=['TT', 'CO', 'HE', 'AV', 'SEATS'], alt_is_prefix=True)
df_train

Unnamed: 0,custom_id,alt,TT,CO,HE,AV,SEATS,GROUP,SURVEY,SP,...,TICKET,WHO,LUGGAGE,AGE,MALE,INCOME,GA,ORIGIN,DEST,CHOICE
0,1248,TRAIN,104,48,120,1,0,2,0,1,...,1,1,1,2,1,2,0,2,25,SM
1,1248,SM,45,62,10,1,0,2,0,1,...,1,1,1,2,1,2,0,2,25,SM
2,1248,CAR,88,65,0,1,0,2,0,1,...,1,1,1,2,1,2,0,2,25,SM
3,4924,TRAIN,224,98,120,1,0,3,1,1,...,1,2,0,4,1,3,0,22,17,SM
4,4924,SM,80,112,30,1,0,3,1,1,...,1,2,0,4,1,3,0,22,17,SM
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4057,4696,SM,171,109,20,1,0,3,1,1,...,1,2,0,3,1,3,0,2,17,SM
4058,4696,CAR,160,123,0,1,0,3,1,1,...,1,2,0,3,1,3,0,2,17,SM
4059,547,TRAIN,185,17,120,1,0,2,0,1,...,10,1,1,3,0,1,0,10,17,SM
4060,547,SM,84,22,20,1,0,2,0,1,...,10,1,1,3,0,1,0,10,17,SM


jaxlogit and xlogit setup:

In [35]:
df_train['ASC_TRAIN'] = np.where(df_train['alt'] == 'TRAIN', 1, 0)
df_train['ASC_CAR'] = np.where(df_train['alt'] == 'CAR', 1, 0)
df_train['ASC_SM'] = np.where(df_train['alt'] == 'SM', 1, 0)

df_train['TT'] = df_train['TT'] / 100.0
df_train['CO'] = df_train['CO'] / 100.0

df_train.loc[(df_train['GA'] == 1) & (df_train['alt'].isin(['TRAIN', 'SM'])), 'CO'] = 0  # Cost zero for pass holders

df_test['ASC_TRAIN'] = np.where(df_test['alt'] == 'TRAIN', 1, 0)
df_test['ASC_CAR'] = np.where(df_test['alt'] == 'CAR', 1, 0)
df_test['ASC_SM'] = np.where(df_test['alt'] == 'SM', 1, 0)

df_test['TT'] = df_test['TT'] / 100.0
df_test['CO'] = df_test['CO'] / 100.0

df_test.loc[(df_test['GA'] == 1) & (df_test['alt'].isin(['TRAIN', 'SM'])), 'CO'] = 0  # Cost zero for pass holders

In [36]:
varnames = ['ASC_SM', 'ASC_CAR', 'ASC_TRAIN', 'TT', 'CO']

randvars = {'CO': 'n', 'TT': 'n'}  

fixedvars = {'ASC_TRAIN': 0.0}

do_panel = True

model_jax = MixedLogit()

config = ConfigData(
    avail=df_train['AV'],
    panels=None if do_panel is False else df_train["ID"],
    n_draws=1000,
    init_coeff=None,
    include_correlations=False,
    optim_method='L-BFGS-jax',
    skip_std_errs=False,
    force_positive_chol_diag=False,  # not using softplus for std devs here for comparability with biogeme
)

res = model_jax.fit(
    X=df_train[varnames],
    y=df_train['CHOICE'],
    varnames=varnames,
    alts=df_train['alt'],
    ids=df_train['custom_id'],
    randvars=randvars,
    config=config
)
model_jax.summary()
init_coeff_jax = model_jax.coeff_

model_scipy = MixedLogit()
config.optim_method = "L-BFGS-scipy"
res = model_scipy.fit(
    X=df_train[varnames],
    y=df_train['CHOICE'],
    varnames=varnames,
    alts=df_train['alt'],
    ids=df_train['custom_id'],
    randvars=randvars,
    config=config
)
model_scipy.summary()
init_coeff_scipy = model_scipy.coeff_

**** The optimization did not converge after 19 iterations. ****
Convergence not reached. The estimates may not be reliable.


    Message: max line search iters reached
    Iterations: 19
    Function evaluations: 27
Estimation time= 5.5 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
ASC_SM                  0.0753688           nan           nan           nan    
ASC_CAR                 0.4149755           nan           nan           nan    
ASC_TRAIN              -0.1903443           nan           nan           nan    
TT                     -4.2349197     0.5283379    -8.0155522      2.35e-15 ***
CO                     -3.6027674     0.4928050    -7.3107368      4.53e-13 ***
sd.TT                   3.9371494     0.5711891     6.8929001      8.36e-12 ***
sd.CO                   3.3806978     0.5176755     6.5305342      9.25e-11 ***
---------------------------------------------------------------------------
Significance:  0

In [37]:
df_train = df_train.sort_values(['custom_id', 'alt'])
df_test = df_test.sort_values(['custom_id', 'alt'])

In [38]:
model_x = xlogit.MixedLogit()
varnames=['ASC_CAR', 'ASC_TRAIN', 'ASC_SM', 'CO', 'TT']
model_x.fit(X=df_train[varnames], y=df_train['CHOICE'], varnames=varnames,
          alts=df_train['alt'], ids=df_train['custom_id'], avail=df_train['AV'],
          panels=df_train["ID"], randvars={'TT': 'n'}, n_draws=1500,
          optim_method='L-BFGS-B')
model_x.summary()
init_coeff_x = model_x.coeff_

Optimization terminated successfully.
    Message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH
    Iterations: 15
    Function evaluations: 16
Estimation time= 6.7 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
ASC_CAR                 0.2947825           nan           nan           nan    
ASC_TRAIN              -0.3132857           nan           nan           nan    
ASC_SM                  0.0185032           nan           nan           nan    
CO                     -1.2288284     0.1484780    -8.2761663      3.01e-16 ***
TT                     -2.8879939     0.2860638   -10.0956284      3.72e-23 ***
sd.TT                   2.7739818     0.3102202     8.9419754      1.22e-18 ***
---------------------------------------------------------------------------
Significance:  0 '***' 0.001 '**' 0

  self.stderr = np.sqrt(np.diag(self.covariance))


# Compare parameters:

In [39]:
print("{:>9} {:>20} {:>15} {:>13}".format("Estimate", "Jaxlogit-scipy", "Jaxlogit-jax", "Xlogit"))
print("-" * 76)
fmt = "{:9} {:18.7f} {:16.7f} {:15.7f}"
coeff_names = {'pf': 'pf_mean', 'sd.pf': 'pf_sd', 'cl': 'cl_mean', 'sd.cl': 'cl_sd', 'loc': 'loc_mean', 'sd.loc': 'loc_sd', 'wk': 'wk_mean', 'sd.wk': 'wk_sd', 'tod': 'tod_mean', 'sd.tod': 'tod_sd', 'seas': 'seas_mean', 'sd.seas': 'seas_sd'}
for i in range(len(model_x.coeff_)):
    name = model_jax.coeff_names[i]
    print(fmt.format(name[:13], 
                     model_scipy.coeff_[i], 
                     model_jax.coeff_[i], 
                     model_x.coeff_[i]))
print("-" * 76)

 Estimate       Jaxlogit-scipy    Jaxlogit-jax        Xlogit
----------------------------------------------------------------------------
ASC_SM             0.0751134        0.0753688       0.2947825
ASC_CAR            0.4162490        0.4149755      -0.3132857
ASC_TRAIN         -0.1913624       -0.1903443       0.0185032
TT                -4.2346783       -4.2349197      -1.2288284
CO                -3.6023248       -3.6027674      -2.8879939
sd.TT              3.9348555        3.9371494       2.7739818
----------------------------------------------------------------------------


# Predict
jaxlogit:

In [40]:
model = model_scipy 
config = ConfigData(
    panels=None if do_panel is False else df_test["ID"],
    n_draws=1000,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="L-BFGS-scipy",
)
config.init_coeff = init_coeff_scipy
prob_j_scipy = model.predict(df_test[varnames], varnames, df_train['alt'], df_train['custom_id'], randvars, config)

In [41]:
model = model_jax
config = ConfigData(
    panels=None if do_panel is False else df_test["ID"],
    n_draws=1000,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="L-BFGS-jax",
)
config.init_coeff = init_coeff_jax
prob_j_jax = model.predict(df_test[varnames], varnames, df_train['alt'], df_train['custom_id'], randvars, config)

xlogit:

In [42]:
_, prob_xx = model_x.predict(df_test[varnames], varnames, df_train['alt'], df_train['custom_id'], isvars=None, panels=None if do_panel is False else df_test["ID"], n_draws=1000, return_proba=True)

Compute the brier score:

In [43]:
# print(df_test['CHOICE'])
y = df_test['CHOICE']
y = y[y.index % 3 == 0]
# print(prob_j_scipy)
print(y)

14838       SM
0           SM
1218        SM
1179        SM
201         SM
         ...  
3666     TRAIN
12042    TRAIN
2481     TRAIN
9858     TRAIN
12816    TRAIN
Name: CHOICE, Length: 5414, dtype: object


In [44]:
print("{:>9} {:>9} {:>9}".format("Jaxlogit-scipy", "Jaxlogit-jax", "xlogit"))
print("-" * 48)
fmt = "{:9f} {:9f} {:9f}"
print(fmt.format(sklearn.metrics.brier_score_loss(y, prob_j_scipy),
                 sklearn.metrics.brier_score_loss(y, prob_j_jax),
                 sklearn.metrics.brier_score_loss(y, prob_xx)))
print("-" * 48)

Jaxlogit-scipy Jaxlogit-jax    xlogit
------------------------------------------------
 0.644901  0.644764  0.623109
------------------------------------------------
