# Summary of time taken and brier scores for jaxlogit, xlogit, and biogeme
Estimation uses 600 draws (suboptimal, but the highest number biogeme can handle without running out of memory), and the training and test data are kept separate.

| | jaxlogit | xlogit | biogeme |
|---|---|---|---|
|Making Model | 37.7s | 16.9s | 4:15 |
|Estimating | 1.6s | 0.0s | 15.4s |
|Brier Score | 0.6345 | 0.6345 | 0.6345 |

# Setup

In [None]:
import pandas as pd
import numpy as np
import jax
import pathlib
import xlogit
from time import time

from jaxlogit.mixed_logit import MixedLogit, ConfigData

import biogeme.biogeme_logging as blog
import biogeme.biogeme as bio
from biogeme import models
from biogeme.expressions import Beta, Draws, log, MonteCarlo, PanelLikelihoodTrajectory
import biogeme.database as db
from biogeme.expressions import Variable

# Enable 64-bit precision in JAX.
jax.config.update("jax_enable_x64", True)

# Screen logger from biogeme, set to INFO level.
logger = blog.get_screen_logger()
logger.setLevel(blog.INFO)

In [None]:
# Dataset switch read by the cells below: True selects the electricity data,
# False selects the artificial data.
dataset = False

# Get the full electricity dataset

Use for jaxlogit and xlogit. Adjusting n_draws can improve accuracy, but Biogeme cannot handle 700 or more draws with this data set.

In [None]:
# Explanatory variables of the selected dataset (electricity first, artificial second).
varnames = (
    ['pf', 'cl', 'loc', 'wk', 'tod', 'seas']
    if dataset
    else ['price', 'time', 'conven', 'comfort', 'meals', 'petfr', 'emipp', 'nonsig1', 'nonsig2', 'nonsig3']
)

Reshape the data so it can be passed to test_train_split in a wide format. Additionally, xlogit and jaxlogit require long format while biogeme requires a wide format.

In [None]:
# Pick the long-format source first, then read it once:
# the local electricity CSV, or the artificial CSV used by xlogit benchmarking.
long_source = (
    pathlib.Path.cwd().parent / "electricity_long.csv"
    if dataset
    else "https://raw.githubusercontent.com/arteagac/xlogit/master/examples/data/artificial_long.csv"
)
df_long = pd.read_csv(long_source)

In [None]:
# Identifier columns of the long data; the last entry is always the alternative.
if dataset:
    keys = ['id', 'chid', 'alt']
else:
    keys = ['id', 'alt']

# Rows flagged as chosen give the chosen alternative per observation.
is_chosen = df_long['choice'] == 1
choice_df = df_long.loc[is_chosen, keys].rename(columns={'alt': 'choice'})

# One row per observation, one "<variable>_<alternative>" column per pair.
df_wide = df_long.pivot(index=keys[:-1], columns='alt', values=varnames)
df_wide.columns = [f'{var}_{alt}' for var, alt in df_wide.columns]

# Attach the chosen alternative; each observation must match exactly one choice row.
df_wide = df_wide.reset_index().merge(
    choice_df,
    how='inner',
    on=keys[:-1],
    validate='one_to_one',
)

# The first argument of db.Database is the database name. `dataset` is only a
# boolean switch, so pass a descriptive string instead of True/False.
database = db.Database('electricity' if dataset else 'artificial', df_wide)
if dataset:
    # Declare the electricity data as panel data keyed on the 'id' column.
    database.panel('id')

# Inputs shared by the jaxlogit and xlogit fits, taken from the long-format data.
y = df_long['choice']
alts = df_long['alt']
panels = df_long['id']

# The electricity data has a separate choice-situation id ('chid');
# the artificial data uses 'id' for both ids and panels.
ids = df_long['chid'] if dataset else df_long['id']

# Variables with random coefficients, each mapped to 'n'.
randvars = (
    {'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'}
    if dataset
    else {'meals': 'n', 'petfr': 'n', 'emipp': 'n'}
)

# Memory Profiling

In [None]:
from memory_profiler import memory_usage

# Results table filled by the profiling cells below: one row per run, holding
# the package label and the value returned by memory_usage for that run.
rdf_fit = pd.DataFrame(columns=["package", "memory"])

In [None]:
# Read a clean copy of the long-format data for the profiling runs:
# the local electricity CSV, or the artificial CSV used by xlogit benchmarking.
profiling_source = (
    pathlib.Path.cwd().parent / "electricity_long.csv"
    if dataset
    else "https://raw.githubusercontent.com/arteagac/xlogit/master/examples/data/artificial_long.csv"
)
df = pd.read_csv(profiling_source)

# Number of draws used by the jaxlogit and xlogit fits below.
n_draws = 5000
# Passed to memory_usage as `interval`; presumably the number of seconds
# between memory samples (confirm against the memory_profiler docs).
interval = 1.0

model_jax, model_x = MixedLogit(), xlogit.MixedLogit()

init_coeff = None

# Build the biogeme model specification (parameters, draws and utilities V)
# for the selected dataset.
# In the Beta(...) calls the arguments are: name, initial value, lower bound,
# upper bound, status (0 = estimated, 1 = fixed), per the biogeme documentation.
if dataset:
    # One biogeme Variable per attribute and alternative, named
    # "<attribute>_<alternative>".
    X = {
        name: {
            j: Variable(f"{name}_{j}")
            for j in [1,2,3,4]
        }
        for name in varnames
    }

    # NOTE(review): alt_1..alt_4 are defined but do not appear in V below.
    alt_1 = Beta('alt_1', 0, None, None, 0)
    alt_2 = Beta('alt_2', 0, None, None, 0)
    alt_3 = Beta('alt_3', 0, None, None, 0)
    alt_4 = Beta('alt_4', 0, None, None, 1)

    # Mean and standard deviation of each random coefficient.
    pf_mean = Beta('pf_mean', 0, None, None, 0)
    pf_sd = Beta('pf_sd', 1, None, None, 0)
    cl_mean = Beta('cl_mean', 0, None, None, 0)
    cl_sd = Beta('cl_sd', 1, None, None, 0)
    loc_mean = Beta('loc_mean', 0, None, None, 0)
    loc_sd = Beta('loc_sd', 1, None, None, 0)
    wk_mean = Beta('wk_mean', 0, None, None, 0)
    wk_sd = Beta('wk_sd', 1, None, None, 0)
    tod_mean = Beta('tod_mean', 0, None, None, 0)
    tod_sd = Beta('tod_sd', 1, None, None, 0)
    seas_mean = Beta('seas_mean', 0, None, None, 0)
    seas_sd = Beta('seas_sd', 1, None, None, 0)

    # Random coefficients: mean + sd * normal draw.
    pf_rnd = pf_mean + pf_sd * Draws('pf_rnd', 'NORMAL')
    cl_rnd = cl_mean + cl_sd * Draws('cl_rnd', 'NORMAL')
    loc_rnd = loc_mean + loc_sd * Draws('loc_rnd', 'NORMAL')
    wk_rnd = wk_mean + wk_sd * Draws('wk_rnd', 'NORMAL')
    tod_rnd = tod_mean + tod_sd * Draws('tod_rnd', 'NORMAL')
    seas_rnd = seas_mean + seas_sd * Draws('seas_rnd', 'NORMAL')

    choice = Variable('choice')

    # Utility of each of the four alternatives.
    V = {
        j: pf_rnd * X['pf'][j] + cl_rnd * X['cl'][j] + loc_rnd * X['loc'][j] + wk_rnd * X['wk'][j] + tod_rnd * X['tod'][j] + seas_rnd * X['seas'][j]
        for j in [1,2,3,4]
    }
else:
    # `df` already holds the artificial long-format CSV, read just above in
    # this cell, so it is not downloaded a second time here.
    # NOTE(review): the Variables below use wide-format names (price_1, ...)
    # while `df` is long-format; the biogeme run is skipped for the artificial
    # dataset ("Does not work on artificial"), so confirm before relying on it.
    df['choice'] = df['choice'].astype('str')
    mapping = {'1': 1, '2': 2, '3': 3}

    # Availability columns: every alternative is marked available in every row.
    for k in mapping:
        df["aval_"+k] = np.ones(df.shape[0])
    # One replace is enough; repeating it per loop iteration changed nothing.
    df = df.replace({'choice': mapping})

    database = db.Database('artificial', df)

    # Fixed params
    # NOTE(review): every Beta in this branch is declared with status 1 -
    # confirm that this is intended.
    b_price = Beta('b_price', 0, None, None, 1)
    b_time = Beta('b_time', 0, None, None, 1)
    b_conven = Beta('b_conven', 0, None, None, 1)
    b_comfort = Beta('b_comfort', 0, None, None, 1)
    b_nonsig1 = Beta('b_nonsig1', 0, None, None, 1)
    b_nonsig2 = Beta('b_nonsig2', 0, None, None, 1)
    b_nonsig3 = Beta('b_nonsig3', 0, None, None, 1)

    price_1 = Variable('price_1')
    price_2 = Variable('price_2')
    price_3 = Variable('price_3')

    time_1 = Variable('time_1')
    time_2 = Variable('time_2')
    time_3 = Variable('time_3')

    conven_1 = Variable('conven_1')
    conven_2 = Variable('conven_2')
    conven_3 = Variable('conven_3')

    comfort_1 = Variable('comfort_1')
    comfort_2 = Variable('comfort_2')
    comfort_3 = Variable('comfort_3')

    meals_1 = Variable('meals_1')
    meals_2 = Variable('meals_2')
    meals_3 = Variable('meals_3')

    petfr_1 = Variable('petfr_1')
    petfr_2 = Variable('petfr_2')
    petfr_3 = Variable('petfr_3')

    emipp_1 = Variable('emipp_1')
    emipp_2 = Variable('emipp_2')
    emipp_3 = Variable('emipp_3')

    nonsig1_1 = Variable('nonsig1_1')
    nonsig1_2 = Variable('nonsig1_2')
    nonsig1_3 = Variable('nonsig1_3')

    nonsig2_1 = Variable('nonsig2_1')
    nonsig2_2 = Variable('nonsig2_2')
    nonsig2_3 = Variable('nonsig2_3')

    nonsig3_1 = Variable('nonsig3_1')
    nonsig3_2 = Variable('nonsig3_2')
    nonsig3_3 = Variable('nonsig3_3')

    aval_1 = Variable('aval_1')
    aval_2 = Variable('aval_2')
    aval_3 = Variable('aval_3')

    choice = Variable('choice')

    # Random params
    u_meals = Beta('u_meals', 0, None, None, 1)
    u_petfr = Beta('u_petfr', 0, None, None, 1)
    u_emipp = Beta('u_emipp', 0, None, None, 1)
    sd_meals = Beta('sd_meals', 0, None, None, 1)
    sd_petfr = Beta('sd_petfr', 0, None, None, 1)
    sd_emipp = Beta('sd_emipp', 0, None, None, 1)

    # Random coefficients: mean + sd * normal draw.
    b_meals = u_meals + sd_meals*Draws('b_meals', 'NORMAL')
    b_petfr = u_petfr + sd_petfr*Draws('b_petfr', 'NORMAL')
    b_emipp = u_emipp + sd_emipp*Draws('b_emipp', 'NORMAL')

    # Utility of each of the three alternatives.
    V1 = (
        price_1*b_price+time_1*b_time+conven_1*b_conven+comfort_1*b_comfort
        +meals_1*b_meals+petfr_1*b_petfr+emipp_1*b_emipp+nonsig1_1*b_nonsig1
        +nonsig2_1*b_nonsig2+nonsig3_1*b_nonsig3
    )
    V2 = (
        price_2*b_price+time_2*b_time+conven_2*b_conven+comfort_2*b_comfort
        +meals_2*b_meals+petfr_2*b_petfr+emipp_2*b_emipp+nonsig1_2*b_nonsig1
        +nonsig2_2*b_nonsig2+nonsig3_2*b_nonsig3
    )
    V3 = (
        price_3*b_price+time_3*b_time+conven_3*b_conven+comfort_3*b_comfort
        +meals_3*b_meals+petfr_3*b_petfr+emipp_3*b_emipp+nonsig1_3*b_nonsig1
        +nonsig2_3*b_nonsig2+nonsig3_3*b_nonsig3
    )

    V = {1: V1, 2: V2, 3: V3}
    av = {1: aval_1, 2: aval_2, 3: aval_3}

In [None]:
# (result label, optimisation method) pairs profiled without batching.
combos = [
    ('jaxlogit_scipy', "L-BFGS-scipy"),
    ('jaxlogit_jax', "L-BFGS-jax"),
]
# Batched runs use only the scipy method: batching with the jax method does not work.
combos_batched = [
    ('jaxlogit_scipy_batched', "L-BFGS-scipy"),
]

In [None]:
# Profile an unbatched jaxlogit fit for each (label, optimisation method) pair
# and store the label with its memory samples in rdf_fit.
for name, method in combos:
    print(f"Doing {name}")
    model_jax = MixedLogit()
    config = ConfigData(
        panels=df['id'],
        n_draws=n_draws,
        skip_std_errs=True,  # skip standard errors to speed up the example
        batch_size=None,
        optim_method=method,
    )

    # The fifth positional argument of fit is the choice-situation ids, as in
    # the batched and xlogit calls. Passing `panels` there would be wrong for
    # the electricity data, where ids ('chid') differ from panels ('id').
    # The panel column is already supplied through `config`.
    fit_args = (df[varnames], y, varnames, alts, ids, randvars, config)
    mem = memory_usage((model_jax.fit, fit_args), interval=interval, multiprocess=True)
    rdf_fit.loc[len(rdf_fit)] = [name, mem]

In [None]:
# Profile a batched jaxlogit fit for each (label, optimisation method) pair
# and store the label with its memory samples in rdf_fit.
for name, method in combos_batched:
    model_jax = MixedLogit()
    config = ConfigData(
        panels=df['id'],
        n_draws=n_draws,
        skip_std_errs=True,  # skip standard errors to speed up the example
        batch_size=1077,
        optim_method=method,
    )

    # Use the untouched choice/alternative columns (`y`, `alts`) like the
    # unbatched loop. df['choice'] is rewritten by the biogeme set-up for the
    # artificial dataset (astype('str') followed by a partial replace), so it
    # is not a reliable input there. For the electricity data both are equal.
    fit_args = (df[varnames], y, varnames, alts, ids, randvars, config)
    mem = memory_usage((model_jax.fit, fit_args), interval=interval, multiprocess=True)
    rdf_fit.loc[len(rdf_fit)] = [name, mem]

In [None]:
# Profile the xlogit fit and store its memory samples in rdf_fit.
# `y` and `alts` are the untouched choice/alternative columns also used by the
# unbatched jaxlogit runs. df['choice'] is rewritten by the biogeme set-up for
# the artificial dataset, so it is not passed here.
xlogit_args = (df[varnames], y, varnames, alts, ids, randvars)
xlogit_kwargs = {
    "panels": df['id'],
    "n_draws": n_draws,
    "skip_std_errs": True,
    "batch_size": None,
    "optim_method": "L-BFGS-B",
}
mem_usage_x = memory_usage((model_x.fit, xlogit_args, xlogit_kwargs), interval=interval, multiprocess=True)
rdf_fit.loc[len(rdf_fit)] = ["xlogit", mem_usage_x]

In [None]:
if dataset:  # Does not work on artificial
    def biogeme_running(V, choice):
        """Estimate the biogeme model so that memory_usage can profile the run.

        Uses the module-level `database` and a fixed 600 draws.

        Args:
            V: Mapping from alternative number to its utility expression.
            choice: biogeme expression identifying the chosen alternative.
        """
        # None is passed as the availability argument of models.logit.
        prob = models.logit(V, None, choice)
        logprob = log(MonteCarlo(PanelLikelihoodTrajectory(prob)))

        the_biogeme = bio.BIOGEME(
            database, logprob, number_of_draws=600, seed=999, generate_yaml=False, generate_html=False
        )
        the_biogeme.model_name = 'model_b'
        # The estimation results are not used; only the memory trace is recorded.
        the_biogeme.estimate()

    mem_usage_biogeme = memory_usage((biogeme_running, (V, choice)), interval=interval, multiprocess=True)
    rdf_fit.loc[len(rdf_fit)] = ["biogeme", mem_usage_biogeme]

Note: before running the next cell, check which of its sections (dump, append, or retrieve) is active and which data file it reads or writes.

In [None]:
import json
import pandas as pd

# Dump all of rdf_fit
# list = rdf_fit.values.tolist()
# with open("comparison_data_memory_artificial.json", "w") as f:
#     json.dump(list, f)

# Append all of rdf_fit
# with open("comparison_data_memory.json", "r") as f:
#     fit_data = json.load(f)
# rdf = pd.DataFrame(fit_data)
# rdf.columns = ["package", "memory"]
# rdf = pd.concat([rdf, rdf_fit], ignore_index=True)
# list = rdf.values.tolist()
# with open("comparison_data_memory.json", "w") as f:
#     json.dump(list, f)

# Retrieve all of rdf_fit
# Load the stored (package, memory samples) rows from disk.
with open("comparison_data_memory.json", "r", encoding="utf-8") as f:
    fit_data = json.load(f)
# Name the columns at construction, so an empty file yields an empty frame with
# the expected columns instead of a length-mismatch error. The former
# pd.concat([rdf]) of a single frame was a no-op copy and has been dropped.
rdf = pd.DataFrame(fit_data, columns=["package", "memory"])

print(rdf)

In [None]:
import matplotlib.pyplot as plt
# Legend label for every package key that can appear in the stored results.
names_dict = {
    "jaxlogit_scipy": "Jaxlogit using L-BFGS-scipy",
    "jaxlogit_scipy_b": "Jaxlogit using BFGS-scipy",
    "jaxlogit_jax": "Jaxlogit using L-BFGS-jax",
    "jaxlogit_jax_b": "Jaxlogit using BFGS-jax",
    "xlogit": "Xlogit",
    "biogeme": "Biogeme",
    "jaxlogit_scipy_batched": "Jaxlogit using L-BFGS-scipy and batching",
    "jaxlogit_scipy_b_batched": "Jaxlogit using BFGS-scipy and batching",
}

# Package keys shown in the cropped plot.
include = [
    "jaxlogit_scipy",
    "jaxlogit_jax",
    "xlogit",
    "biogeme",
    "jaxlogit_scipy_batched",
]

# Draw one memory curve per stored run, labelled with its display name.
for _, run in rdf.iterrows():
    samples = run["memory"]
    legend_label = names_dict[run["package"]]
    plt.plot(samples, label=legend_label)

# Axis text and title first, then limits and legend.
plt.title("Memory Usage Over Time")
plt.xlabel("Time (samples)")
plt.ylabel("Memory (MB)")
plt.xlim(left=0)
plt.legend(bbox_to_anchor=(1, 1))

# Write the figure to disk before displaying it.
plt.savefig("memory_comparison", bbox_inches="tight")
plt.show()

In [None]:
# Same plot restricted to the packages listed in `include`, with cropped axes.
for _, run in rdf.iterrows():
    if run["package"] in include:
        plt.plot(run["memory"], label=names_dict[run["package"]])

# Axis text and title first, then the cropped limits and legend.
plt.title("Memory Usage Over Time")
plt.xlabel("Time (samples)")
plt.ylabel("Memory (MB)")
plt.xlim(left=0, right=30)
plt.ylim(top=15000)
plt.legend(bbox_to_anchor=(1, 0.6))

# Write the figure to disk before displaying it.
plt.savefig("memory_comparison_cropped", bbox_inches="tight")
plt.show()