# Summary of time taken and Brier scores for jaxlogit, xlogit, and Biogeme
Estimation uses 600 draws (suboptimal, but the highest number Biogeme can handle without running out of memory), and the training and test data are kept separate.

| | jaxlogit | xlogit | biogeme |
|---|---|---|---|
|Making Model | 37.7s | 16.9s | 4:15 |
|Estimating | 1.6s | 0.0s | 15.4s |
|Brier Score | 0.6345 | 0.6345 | 0.6345 |

# Setup

In [None]:
import pandas as pd
import numpy as np
import jax
import pathlib
import xlogit
from time import time

from jaxlogit.mixed_logit import MixedLogit, ConfigData

import biogeme.biogeme_logging as blog
import biogeme.biogeme as bio
from biogeme import models
from biogeme.expressions import Beta, Draws, log, MonteCarlo, PanelLikelihoodTrajectory
import biogeme.database as db
from biogeme.expressions import Variable

# Obtain Biogeme's screen logger and set it to INFO level.
logger = blog.get_screen_logger()
logger.setLevel(blog.INFO)

# Enable 64-bit precision in JAX.
jax.config.update("jax_enable_x64", True)

# Get the full electricity dataset

Used for jaxlogit and xlogit. Adjusting n_draws can improve accuracy, but Biogeme cannot handle 700 or more draws with this dataset.

In [None]:
# Dataset switch: True selects the electricity dataset, False the artificial one.
dataset = False

In [None]:
# Explanatory variables of the selected dataset (electricity vs. artificial).
varnames = (
    ['pf', 'cl', 'loc', 'wk', 'tod', 'seas']
    if dataset
    else [
        'price', 'time', 'conven', 'comfort', 'meals',
        'petfr', 'emipp', 'nonsig1', 'nonsig2', 'nonsig3',
    ]
)

# Draw counts used for the Biogeme runs, and the larger ones that are
# additionally used for the jaxlogit and xlogit runs.
quick_draws_biogeme = [100, 200, 400, 500, 600]
quick_draws_extended = [
    700, 800, 900, 1000, 1500,
    2000, 2500, 3000, 4000, 5000,
]

# Number of repetitions of every timed fit.
trials = 3

In [None]:
# Empty result tables: `rdf_fit` receives one row per timed fit;
# `rdf_predict` has the same layout without the "draws" column.
rdf_fit, rdf_predict = (
    pd.DataFrame(columns=column_names)
    for column_names in (["package", "draws", "time"], ["package", "time"])
)

Reshape the data into wide format so that it can be passed to test_train_split. Note that xlogit and jaxlogit require long format, while Biogeme requires wide format.

In [None]:
# Load the long-format data: the local electricity CSV, or the artificial
# long dataset used by the xlogit benchmarks.
df_long = pd.read_csv(
    pathlib.Path.cwd().parent / "electricity_long.csv"
    if dataset
    else "https://raw.githubusercontent.com/arteagac/xlogit/master/examples/data/artificial_long.csv"
)

In [None]:
# Columns identifying a row of the long table; all but the last ('alt')
# identify a row of the wide table.
keys = ['id', 'chid', 'alt'] if dataset else ['id', 'alt']

# Chosen alternative per wide row: keep the rows flagged choice == 1 and
# relabel their 'alt' column as 'choice'.
choice_df = df_long.loc[df_long['choice'] == 1, keys]
choice_df = choice_df.rename(columns={'alt': 'choice'})

# Spread every variable into one column per alternative, named '<var>_<alt>'.
df_wide = df_long.pivot(index=keys[:-1], columns='alt', values=varnames)
df_wide.columns = [f'{var}_{alt}' for var, alt in df_wide.columns]
df_wide = df_wide.reset_index()
df_wide = df_wide.merge(
    choice_df,
    on=keys[:-1],
    how='inner',
    validate='one_to_one'
)

# Give the Biogeme database a descriptive string name derived from the
# flag instead of passing the boolean flag itself as the name.
database = db.Database('electricity' if dataset else 'artificial', df_wide)
if dataset:
    database.panel('id')

# Long-format inputs for jaxlogit / xlogit. Only the choice-situation ids
# and the random-coefficient specification depend on the dataset.
y = df_long['choice']
alts = df_long['alt']
panels = df_long['id']
if dataset:
    ids = df_long['chid']
    randvars = {'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'}
else:
    ids = df_long['id']
    randvars = {'meals': 'n', 'petfr': 'n', 'emipp': 'n'}

In [None]:
# Three independent jaxlogit models and one xlogit model.
model_jax_scipy_batched, model_jax_scipy_l, model_jax_jax_l = (
    MixedLogit() for _ in range(3)
)
model_x = xlogit.MixedLogit()

# No starting coefficients are supplied.
init_coeff = None

Biogeme setup:

In [None]:
# Biogeme model specification: utility functions V (one per alternative)
# for the selected dataset.
if dataset:
    # Wide-format attribute columns, indexed as X[attribute][alternative].
    X = {
        name: {
            j: Variable(f"{name}_{j}")
            for j in [1,2,3,4]
        }
        for name in varnames
    }

    # NOTE(review): alt_1..alt_4 are defined but do not appear in V below.
    alt_1 = Beta('alt_1', 0, None, None, 0)
    alt_2 = Beta('alt_2', 0, None, None, 0)
    alt_3 = Beta('alt_3', 0, None, None, 0)
    alt_4 = Beta('alt_4', 0, None, None, 1)

    # Mean and standard deviation parameter of every random coefficient.
    pf_mean = Beta('pf_mean', 0, None, None, 0)
    pf_sd = Beta('pf_sd', 1, None, None, 0)
    cl_mean = Beta('cl_mean', 0, None, None, 0)
    cl_sd = Beta('cl_sd', 1, None, None, 0)
    loc_mean = Beta('loc_mean', 0, None, None, 0)
    loc_sd = Beta('loc_sd', 1, None, None, 0)
    wk_mean = Beta('wk_mean', 0, None, None, 0)
    wk_sd = Beta('wk_sd', 1, None, None, 0)
    tod_mean = Beta('tod_mean', 0, None, None, 0)
    tod_sd = Beta('tod_sd', 1, None, None, 0)
    seas_mean = Beta('seas_mean', 0, None, None, 0)
    seas_sd = Beta('seas_sd', 1, None, None, 0)

    # Random coefficients: mean + sd * normal draw.
    pf_rnd = pf_mean + pf_sd * Draws('pf_rnd', 'NORMAL')
    cl_rnd = cl_mean + cl_sd * Draws('cl_rnd', 'NORMAL')
    loc_rnd = loc_mean + loc_sd * Draws('loc_rnd', 'NORMAL')
    wk_rnd = wk_mean + wk_sd * Draws('wk_rnd', 'NORMAL')
    tod_rnd = tod_mean + tod_sd * Draws('tod_rnd', 'NORMAL')
    seas_rnd = seas_mean + seas_sd * Draws('seas_rnd', 'NORMAL')

    choice = Variable('choice')

    # Utility of each of the four alternatives.
    V = {
        j: pf_rnd * X['pf'][j] + cl_rnd * X['cl'][j] + loc_rnd * X['loc'][j] + wk_rnd * X['wk'][j] + tod_rnd * X['tod'][j] + seas_rnd * X['seas'][j]
        for j in [1,2,3,4]
    }
else:
    df = pd.read_csv("https://raw.githubusercontent.com/arteagac/xlogit/master/examples/data/artificial_long.csv")
    print(df.columns)
    df['choice'] = df['choice'].astype('str')
    mapping = {'1': 1, '2': 2, '3': 3}

    # Availability columns: every alternative is marked as available (all ones).
    for k in mapping:
        df["aval_"+k] = np.ones(df.shape[0])
    # Map the string labels to integers once; the mapping does not depend
    # on the loop above, so it is applied a single time after it.
    df = df.replace({'choice': mapping})

    # NOTE(review): this database is built from the long-format file, while
    # the Variables below use wide-style names such as 'price_1' -- confirm
    # that these columns exist before estimating with Biogeme.
    database = db.Database('artificial', df)

    # Fixed params
    b_price = Beta('b_price', 0, None, None, 0)
    b_time = Beta('b_time', 0, None, None, 0)
    b_conven = Beta('b_conven', 0, None, None, 0)
    b_comfort = Beta('b_comfort', 0, None, None, 0)
    b_nonsig1 = Beta('b_nonsig1', 0, None, None, 0)
    b_nonsig2 = Beta('b_nonsig2', 0, None, None, 0)
    b_nonsig3 = Beta('b_nonsig3', 0, None, None, 0)

    # Attribute variables, one per attribute and alternative.
    price_1 = Variable('price_1')
    price_2 = Variable('price_2')
    price_3 = Variable('price_3')

    time_1 = Variable('time_1')
    time_2 = Variable('time_2')
    time_3 = Variable('time_3')

    conven_1 = Variable('conven_1')
    conven_2 = Variable('conven_2')
    conven_3 = Variable('conven_3')

    comfort_1 = Variable('comfort_1')
    comfort_2 = Variable('comfort_2')
    comfort_3 = Variable('comfort_3')

    meals_1 = Variable('meals_1')
    meals_2 = Variable('meals_2')
    meals_3 = Variable('meals_3')

    petfr_1 = Variable('petfr_1')
    petfr_2 = Variable('petfr_2')
    petfr_3 = Variable('petfr_3')

    emipp_1 = Variable('emipp_1')
    emipp_2 = Variable('emipp_2')
    emipp_3 = Variable('emipp_3')

    nonsig1_1 = Variable('nonsig1_1')
    nonsig1_2 = Variable('nonsig1_2')
    nonsig1_3 = Variable('nonsig1_3')

    nonsig2_1 = Variable('nonsig2_1')
    nonsig2_2 = Variable('nonsig2_2')
    nonsig2_3 = Variable('nonsig2_3')

    nonsig3_1 = Variable('nonsig3_1')
    nonsig3_2 = Variable('nonsig3_2')
    nonsig3_3 = Variable('nonsig3_3')

    aval_1 = Variable('aval_1')
    aval_2 = Variable('aval_2')
    aval_3 = Variable('aval_3')

    # Random params
    u_meals = Beta('u_meals', 0, None, None, 0)
    u_petfr = Beta('u_petfr', 0, None, None, 0)
    u_emipp = Beta('u_emipp', 0, None, None, 0)
    sd_meals = Beta('sd_meals', 0, None, None, 0)
    sd_petfr = Beta('sd_petfr', 0, None, None, 0)
    sd_emipp = Beta('sd_emipp', 0, None, None, 0)

    # Random coefficients: mean + sd * normal draw.
    b_meals = u_meals + sd_meals*Draws('b_meals', 'NORMAL')
    b_petfr = u_petfr + sd_petfr*Draws('b_petfr', 'NORMAL')
    b_emipp = u_emipp + sd_emipp*Draws('b_emipp', 'NORMAL')

    # Utility of each of the three alternatives.
    V1 = (
        price_1*b_price+time_1*b_time+conven_1*b_conven+comfort_1*b_comfort
        + meals_1*b_meals+petfr_1*b_petfr+emipp_1*b_emipp+nonsig1_1*b_nonsig1
        + nonsig2_1*b_nonsig2+nonsig3_1*b_nonsig3
    )
    V2 = (
        price_2*b_price+time_2*b_time+conven_2*b_conven+comfort_2*b_comfort
        + meals_2*b_meals+petfr_2*b_petfr+emipp_2*b_emipp+nonsig1_2*b_nonsig1
        + nonsig2_2*b_nonsig2+nonsig3_2*b_nonsig3
    )
    V3 = (
        price_3*b_price+time_3*b_time+conven_3*b_conven+comfort_3*b_comfort
        + meals_3*b_meals+petfr_3*b_petfr+emipp_3*b_emipp+nonsig1_3*b_nonsig1
        + nonsig2_3*b_nonsig2+nonsig3_3*b_nonsig3
    )

    V = {1: V1, 2: V2, 3: V3}
    av = {1: aval_1, 2: aval_2, 3: aval_3}

# Make the models
Jaxlogit:

In [None]:
# (label stored in the results table, jaxlogit optim_method) pairs for the
# unbatched runs ...
combos = [
    ('jaxlogit_scipy', "L-BFGS-scipy"),
    ('jaxlogit_jax', "L-BFGS-jax"),
]
# ... and for the batched runs.
combos_batched = [
    ('jaxlogit_scipy_batched', "L-BFGS-scipy"),
]

In [None]:
# Time an unbatched jaxlogit fit for every optimiser, trial and draw count,
# appending one row per fit to rdf_fit.
for name, method in combos:
    for trial in range(trials):
        print(f"Using {name} trial {trial + 1}")
        for draw in quick_draws_biogeme + quick_draws_extended:
            model_jax = MixedLogit()
            print(f"starting {draw}")
            # The timed region covers building the config and fitting.
            start_time = time()
            config = ConfigData(
                panels=panels,
                n_draws=draw,
                skip_std_errs=True,  # standard errors are skipped for speed
                batch_size=None,
                optim_method=method,
            )
            init_coeff = None

            model_jax.fit(
                X=df_long[varnames],
                y=df_long['choice'],
                varnames=varnames,
                ids=ids,
                alts=df_long['alt'],
                randvars=randvars,
                config=config,
            )
            elapsed = time() - start_time
            rdf_fit.loc[len(rdf_fit)] = [name, draw, elapsed]

In [None]:
# Time a batched jaxlogit fit (batch_size=539) for every optimiser, trial
# and draw count, appending one row per fit to rdf_fit.
for name, method in combos_batched:
    for trial in range(trials):
        print(f"Using {name} trial {trial + 1}")
        for draw in quick_draws_biogeme + quick_draws_extended:
            model_jax = MixedLogit()
            print(f"starting {draw}")
            # The timed region covers building the config and fitting.
            start_time = time()
            config = ConfigData(
                panels=panels,
                n_draws=draw,
                skip_std_errs=True,  # standard errors are skipped for speed
                batch_size=539,
                optim_method=method,
            )
            init_coeff = None

            model_jax.fit(
                X=df_long[varnames],
                y=df_long['choice'],
                varnames=varnames,
                ids=ids,
                alts=df_long['alt'],
                randvars=randvars,
                config=config,
            )
            elapsed = time() - start_time
            rdf_fit.loc[len(rdf_fit)] = [name, draw, elapsed]

xlogit:

In [None]:
# Time an xlogit fit for every draw count and trial, appending one row per
# fit to rdf_fit. The same model_x instance is refitted each time.
for draw in quick_draws_biogeme + quick_draws_extended:
    for trial in range(trials):
        print(f"Using xlogit trial {trial + 1}")
        start_time = time()
        model_x.fit(
            X=df_long[varnames],
            y=y,
            varnames=varnames,
            ids=ids,
            alts=alts,
            randvars=randvars,
            panels=panels,
            n_draws=draw,
            skip_std_errs=True,  # standard errors are skipped for speed
            batch_size=None,
            optim_method="L-BFGS-B",
        )
        elapsed = time() - start_time
        rdf_fit.loc[len(rdf_fit)] = ['xlogit', draw, elapsed]

Biogeme:

In [None]:
# Biogeme is only run for the electricity dataset; it does not work with the
# artificial one.
if dataset:
    for draw in quick_draws_biogeme:
        for trial in range(trials):
            print(f"Using biogeme trial {trial + 1}")
            print(f"starting draw {draw}")
            # The timed region covers building the likelihood expression,
            # creating the BIOGEME object and estimating.
            start_time = time()
            prob = models.logit(V, None, choice)
            logprob = log(MonteCarlo(PanelLikelihoodTrajectory(prob)))

            the_biogeme = bio.BIOGEME(
                database,
                logprob,
                number_of_draws=draw,
                seed=999,
                generate_yaml=False,
                generate_html=False,
            )
            the_biogeme.model_name = 'model_b'
            results = the_biogeme.estimate()
            elapsed = time() - start_time
            rdf_fit.loc[len(rdf_fit)] = ['biogeme', draw, elapsed]

# Graphs

Note: before running the next cell, check which of the dump, append, or retrieve snippets is active and which data file it reads from or writes to.

In [None]:
import json
import pandas as pd

# Dump all of rdf_fit
# The rows are bound to `records` rather than `list`, so that the built-in
# `list` stays callable in the cells that follow.
records = rdf_fit.values.tolist()
with open("comparison_data_timing_artificial.json", "w", encoding="utf-8") as f:
    json.dump(records, f)
# Make the rows that were just dumped available as `rdf` for the plotting
# cells, built the same way as in the retrieve snippet below.
rdf = pd.DataFrame(records, columns=["package", "draws", "time"])

# Append all of rdf_fit
# with open("comparison_data_timing.json", "r") as f:
#     fit_data = json.load(f)
# rdf = pd.DataFrame(fit_data)
# rdf.columns = ["package", "draws", "time"]
# rdf = pd.concat([rdf, rdf_fit], ignore_index=True)
# records = rdf.values.tolist()
# with open("comparison_data_timing.json", "w") as f:
#     json.dump(records, f)

# Retrieve all of rdf_fit
# with open("comparison_data_timing.json", "r") as f:
#     fit_data = json.load(f)
# rdf = pd.DataFrame(fit_data)
# rdf.columns = ["package", "draws", "time"]

In [None]:
import matplotlib.pyplot as plt
import matplotlib
# Plot styling: font size 14, right and top axes spines switched off.
matplotlib.rcParams.update(
    {
        'font.size': 14,
        'axes.spines.right': False,
        'axes.spines.top': False,
    }
)

# Package labels in the order in which they are plotted.
libs = [
    'jaxlogit_scipy',
    'jaxlogit_scipy_b',
    'jaxlogit_jax',
    'jaxlogit_jax_b',
    'jaxlogit_scipy_batched',
    'jaxlogit_scipy_b_batched',
    'xlogit',
    'biogeme',
]

In [None]:
# Find the minimum times
from itertools import product

def minimum(rdf):
    """Return the fastest recorded time for each (package, draws) pair.

    Args:
        rdf: DataFrame with "package", "draws" and "time" columns, one row
            per timed run.

    Returns:
        A new DataFrame with the same three columns, holding the minimum
        "time" of every (package, draws) combination that has at least one
        row. Rows are ordered by package, then by draws, each in order of
        first appearance in ``rdf``.
    """
    temp = pd.DataFrame(columns=["package", "draws", "time"])
    packages = rdf["package"].unique()
    draws = rdf["draws"].unique()
    # Iterate the product lazily: no throwaway list is needed, and the
    # function keeps working even if the name `list` has been rebound in
    # the notebook namespace.
    for package, draw in product(packages, draws):
        times = rdf.loc[(rdf["package"] == package) & (rdf["draws"] == draw), "time"]
        # Not every package was run with every draw count.
        if not times.empty:
            temp.loc[len(temp)] = [package, draw, times.min()]

    return temp

# Collapse the repeated trials to the minimum time per package and draw count.
rdf = minimum(rdf)

In [None]:
# Legend text for every package label stored in the results table.
names_dict = {
    "jaxlogit_scipy": "Jaxlogit using L-BFGS-scipy",
    "jaxlogit_scipy_b": "Jaxlogit using BFGS-scipy",
    "jaxlogit_jax": "Jaxlogit using L-BFGS-jax",
    "jaxlogit_jax_b": "Jaxlogit using BFGS-jax",
    "xlogit": "Xlogit",
    "biogeme": "Biogeme",
    "jaxlogit_scipy_batched": "Jaxlogit using L-BFGS-scipy and batching",
    "jaxlogit_scipy_b_batched": "Jaxlogit using BFGS-scipy and batching",
}

def plot(df, name, title="Estimation time"):
    """Plot time against number of draws, one line per package in ``libs``.

    The figure is saved under ``name``, shown, and then closed.

    Args:
        df: DataFrame with "package", "draws" and "time" columns.
        name: File name (or path) passed to ``plt.savefig``.
        title: Title of the figure.
    """
    for lib in libs:
        # Order the points by draw count so the line never doubles back.
        subset = df.loc[df["package"] == lib].sort_values("draws")
        # A package without data is still plotted, so that the remaining
        # packages keep their position in the colour cycle, but it gets an
        # underscore-prefixed label, which the automatic legend skips.
        label = names_dict[lib] if not subset.empty else "_nolegend_"
        plt.plot(subset["draws"], subset["time"], label=label)
    plt.legend(bbox_to_anchor=(1, 1))
    plt.xlabel("Random draws")
    plt.ylabel("Time (Seconds)")
    plt.title(title)
    # Save before showing the figure.
    plt.savefig(name, bbox_inches="tight")
    plt.show()
    plt.close()

# Draw the estimation-time figure and save it as "fit_estimation_time".
plot(rdf, "fit_estimation_time")