# Example of running jaxlogit with batched draws

By default, jaxlogit generates the random draws for simulation once at the beginning and reuses them to calculate the log-likelihood at each step of the optimization routine. The corresponding array has size (number_of_observations x number_of_random_variables x number_of_draws), which can get very large. If this is too large for local memory, jaxlogit can dynamically generate the draws on each iteration instead. The advantage is that the calculation can then be batched, i.e., processed on smaller subsets whose results are added up. This reduces memory load at the cost of runtime. Note that jax still calculates gradients, so this method also has memory limits.
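To make the batching idea concrete, here is a minimal NumPy sketch of a simulated log-likelihood that is summed over batches of observations, with the draws regenerated per batch from a fixed seed. It is an illustration only and not jaxlogit's implementation: the toy binary-logit probability and the helper names `make_draws`, `batch_loglik` and `batched_loglik` are assumptions made for this example.

```python
import numpy as np


def make_draws(batch_index, n_rows, n_vars, n_draws, seed=0):
    # Seeding with (seed, batch_index) regenerates identical draws for a
    # given batch every time, so the objective stays deterministic.
    rng = np.random.default_rng([seed, batch_index])
    return rng.standard_normal((n_rows, n_vars, n_draws))


def batch_loglik(x, draws, mean, sd):
    # x: (n_rows, n_vars), draws: (n_rows, n_vars, n_draws)
    beta = mean[None, :, None] + sd[None, :, None] * draws
    utility = np.einsum("nk,nkr->nr", x, beta)
    prob = 1.0 / (1.0 + np.exp(-utility))  # toy binary-logit probability
    # Average over draws, then sum the log-probabilities over rows.
    return np.log(prob.mean(axis=1)).sum()


def batched_loglik(x, mean, sd, n_draws, batch_size):
    total = 0.0
    for b, start in enumerate(range(0, x.shape[0], batch_size)):
        x_b = x[start:start + batch_size]
        # Only one batch of draws is held in memory at a time.
        draws = make_draws(b, x_b.shape[0], x_b.shape[1], n_draws)
        total += batch_loglik(x_b, draws, mean, sd)
    return total


x = np.random.default_rng(1).standard_normal((10, 2))
mean = np.array([0.5, -0.3])
sd = np.array([1.0, 0.2])
n_draws, batch_size = 50, 4

batched = batched_loglik(x, mean, sd, n_draws, batch_size)

# Reference: build the full draws array at once from the same per-batch draws.
all_draws = np.concatenate(
    [
        make_draws(b, len(x[s:s + batch_size]), x.shape[1], n_draws)
        for b, s in enumerate(range(0, x.shape[0], batch_size))
    ]
)
full = batch_loglik(x, all_draws, mean, sd)

print(np.isclose(batched, full))
```

Because the log-likelihood is a sum over observations, adding up the per-batch contributions gives the same value as the single large computation, while only one batch of draws needs to exist at any time.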

In [1]:
import os
os.chdir("/home/evelyn/projects_shared/jaxlogit")

import pandas as pd
import numpy as np
import jax

from jaxlogit.mixed_logit import MixedLogit, ConfigData
os.chdir("/home/evelyn/projects_shared/jaxlogit/examples")

In [2]:
# Enable 64-bit precision
jax.config.update("jax_enable_x64", True)

## Electricity Dataset

The dataset is taken from xlogit's examples. Note that we skip the calculation of standard errors here to speed up run times.

In [3]:
df = pd.read_csv("https://raw.github.com/arteagac/xlogit/master/examples/data/electricity_long.csv")

In [4]:
n_obs = df['chid'].unique().shape[0]
n_vars = 6
n_draws = 1000

size_in_ram = (n_obs * n_vars * n_draws * 8) / (1024 ** 3)  # in GB

print(
    f"Data has {n_obs} observations, we use {n_vars} random variables in the model. We work in 64 bit precision, so each element is 8 bytes."
    + f" For {n_draws} draws, the array of draws is about {size_in_ram:.2f} GB."
)

varnames = ['pf', 'cl', 'loc', 'wk', 'tod', 'seas']

Data has 4308 observations, we use 6 random variables in the model. We work in 64 bit precision, so each element is 8 bytes. For 1000 draws, the array of draws is about 0.19 GB.
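The same arithmetic can be applied per batch to get a rough idea of how much smaller the draws array becomes. The sketch below assumes that `batch_size` counts observations per batch, which is consistent with 4308 / 1077 = 4 but is not confirmed against jaxlogit's documentation; `draws_size_gb` is a hypothetical helper written for this example.

```python
import math


def draws_size_gb(n_rows, n_vars, n_draws, bytes_per_element=8):
    # Size of a (n_rows x n_vars x n_draws) float64 array in GB.
    return n_rows * n_vars * n_draws * bytes_per_element / 1024**3


n_obs, n_vars, n_draws = 4308, 6, 1000
full_gb = draws_size_gb(n_obs, n_vars, n_draws)
print(f"all observations at once: {full_gb:.2f} GB")

n_batches_by_size = {}
for batch_size in (1077, 539):
    # Assumption: batch_size is the number of observations per batch.
    n_batches = math.ceil(n_obs / batch_size)
    n_batches_by_size[batch_size] = n_batches
    batch_gb = draws_size_gb(batch_size, n_vars, n_draws)
    print(f"batch_size={batch_size}: {n_batches} batches, {batch_gb:.3f} GB of draws per batch")
```

This only counts the draws array itself. Intermediate arrays and the gradient calculation add to the actual memory footprint.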


In [5]:
model = MixedLogit()

config = ConfigData(
    panels=df['id'],
    n_draws=n_draws,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=None,
    optim_method="BFGS-scipy",
    maxiter=100,
)

res = model.fit(
    X=df[varnames],
    y=df['choice'],
    varnames=varnames,
    ids=df['chid'],
    alts=df['alt'],
    randvars={'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'},
    config=config
)
display(model.summary())



    Message: unknown
    Iterations: 70
    Function evaluations: 78
Estimation time= 53.4 seconds
---------------------------------------------------------------------------
Coefficient              Estimate      Std.Err.         z-val         P>|z|
---------------------------------------------------------------------------
pf                     -1.0206555     1.0000000    -1.0206555         0.307    
cl                     -0.2446836     1.0000000    -0.2446836         0.807    
loc                     2.3320637     1.0000000     2.3320637        0.0197 *  
wk                      1.6238631     1.0000000     1.6238631         0.104    
tod                    -9.7212850     1.0000000    -9.7212850      4.12e-22 ***
seas                   -9.6267158     1.0000000    -9.6267158      1.02e-21 ***
sd.pf                  -1.2174305     1.0000000    -1.2174305         0.224    
sd.cl                  -0.7125495     1.0000000    -0.7125495         0.476    
sd.loc                  1.7113082

None

In [None]:
model = MixedLogit()

config = ConfigData(
    panels=df['id'],
    n_draws=n_draws,
    skip_std_errs=True,  # skip standard errors to speed up the example
    batch_size=1077,  # should result in 4 batches
    optim_method="BFGS-scipy",
)

res = model.fit(
    X=df[varnames],
    y=df['choice'],
    varnames=varnames,
    ids=df['chid'],
    alts=df['alt'],
    randvars={'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'},
    config=config
)
display(model.summary())

In [None]:
model = MixedLogit()

config = ConfigData(
    panels=df['id'],
    n_draws=n_draws,
    skip_std_errs=True,  # skip standard errors to speed up the example
    optim_method="L-BFGS-B",
    batch_size=539,
)

res = model.fit(
    X=df[varnames],
    y=df['choice'],
    varnames=varnames,
    ids=df['chid'],
    alts=df['alt'],
    randvars={'pf': 'n', 'cl': 'n', 'loc': 'n', 'wk': 'n', 'tod': 'n', 'seas': 'n'},
    config=config
)
display(model.summary())