In [None]:
import pandas as pd
import numpy as np
import jax

import os

# Root of the jaxlogit checkout. Set the JAXLOGIT_ROOT environment variable to
# run this notebook on another machine; the default is the original location.
JAXLOGIT_ROOT = os.environ.get("JAXLOGIT_ROOT", "/home/evelyn/projects_shared/jaxlogit")

# Switch to the checkout root before importing -- presumably so the in-repo
# `jaxlogit` package is picked up from the working directory (confirm).
os.chdir(JAXLOGIT_ROOT)

from jaxlogit.mixed_logit import MixedLogit, ConfigData

# The rest of the notebook runs from the `examples` directory of the checkout.
os.chdir(os.path.join(JAXLOGIT_ROOT, "examples"))

# Turn on JAX's "jax_enable_x64" flag so computations use 64-bit precision.
jax.config.update("jax_enable_x64", True)

In [None]:
# Download the tab-separated Swissmetro data file (wide format).
df_wide = pd.read_table("http://transp-or.epfl.ch/data/swissmetro.dat", sep='\t')

# Keep only observations for commute and business purposes that contain known choices.
# `.copy()` detaches the filtered frame from the original, so the column
# assignments below do not trigger pandas' SettingWithCopyWarning.
df_wide = df_wide[(df_wide['PURPOSE'].isin([1, 3]) & (df_wide['CHOICE'] != 0))].copy()

df_wide['custom_id'] = np.arange(len(df_wide))  # Add unique identifier
# Replace the numeric choice codes with the alternative names used below.
df_wide['CHOICE'] = df_wide['CHOICE'].map({1: 'TRAIN', 2:'SM', 3: 'CAR'})
# Preview the first rows (rendered only when this is the last expression of a cell).
df_wide.head()

from jaxlogit.utils import wide_to_long

# Reshape the wide table into long format with `wide_to_long`.
# The arguments name `custom_id` as the id column, `alt` as the output column
# holding the alternative, and TT/CO/HE/AV/SEATS as the alternative-varying
# columns. With sep='_' and alt_is_prefix=True the wide columns are presumably
# named like TRAIN_TT -- confirm against the wide_to_long documentation.
df = wide_to_long(
    df_wide,
    id_col='custom_id',
    alt_name='alt',
    sep='_',
    alt_list=['TRAIN', 'SM', 'CAR'],
    empty_val=0,
    varying=['TT', 'CO', 'HE', 'AV', 'SEATS'],
    alt_is_prefix=True,
)
df

# Alternative-specific constants: 1.0 on rows of the named alternative, 0.0 elsewhere.
df['ASC_TRAIN'] = (df['alt'] == 'TRAIN').astype(float)
df['ASC_CAR'] = (df['alt'] == 'CAR').astype(float)

# Scale variables by 1/100.
# NOTE: this rescales in place, so re-running the cell divides by 100 again.
df['TT'] = df['TT'] / 100
df['CO'] = df['CO'] / 100

# Cost zero for pass holders: rows with GA == 1 on the TRAIN and SM alternatives.
annual_pass = df['GA'].eq(1) & df['alt'].isin(['TRAIN', 'SM'])
df.loc[annual_pass, 'CO'] = 0

In [None]:
# Nest specification handed to `model.fit(..., nests=nests)`: TRAIN and SM are
# grouped under 'public transport'; CAR is not listed in any nest.
nests = {'public transport': ['TRAIN', 'SM']}

In [None]:
# Explanatory variables entering the model.
varnames = ['ASC_CAR', 'ASC_TRAIN', 'CO', 'TT']
# Random-coefficient specification: TT is tagged 'n'
# (presumably a normal distribution -- confirm against the jaxlogit docs).
randvars = {'TT': 'n'}

# Estimation settings: number of draws, the availability column and the
# column passed as the panel identifier.
config = ConfigData(
    n_draws=1500,
    avail=df['AV'],
    panels=df["ID"],
)

model = MixedLogit()
# Positional arguments, in order: design columns, chosen alternative,
# variable names, alternative labels, observation ids, random-variable spec
# and the config object.
res = model.fit(
    df[varnames],
    df['CHOICE'],
    varnames,
    df['alt'],
    df['custom_id'],
    randvars,
    config,
    nests=nests,
)
model.summary()