# Specify tables and columns

In [None]:
import numpy as np
import pandas as pd
import orca
import os; os.chdir('../')
import warnings; warnings.simplefilter('ignore')

# Set data directory: use the orca injectable 'data_directory' when one is
# registered, otherwise fall back to the default path below.

d = (
    orca.get_injectable('data_directory')
    if 'data_directory' in orca.list_injectables()
    else '/home/data/fall_2018/'
)
    
#from scripts import datasources, models, variables

In [None]:
# Override orca persons and students tables for estimation.
@orca.table(cache=True)
def persons():
    """Return the CHTS persons table indexed by (SAMPN, PERNO).

    The table is read from 'chts_persons_w_zone_ids.csv' inside the data
    directory ``d``. The path is assembled with ``os.path.join`` so it is
    valid whether or not ``d`` ends with a path separator.

    Returns
    -------
    pandas.DataFrame
    """
    return pd.read_csv(
        os.path.join(d, 'chts_persons_w_zone_ids.csv'),
        index_col=["SAMPN", "PERNO"],
    )

#persons_chts = persons_df1.join(persons_df2, how="left")

persons = orca.get_table('persons').to_frame()

# Keep persons who are students and attend a known, named school.
# STUDE codes kept: 1, 2 (full time & part time students).
# SCHOL codes kept:
#   3 - Kindergarten to grade 8
#   4 - Grades 9 to 12
#   6 - 2-year college (community college)
#   7 - 4-year college or university
#   8 - Graduate school / Professional
# Rows whose school name is missing or recorded as "DK/RF" are dropped.
students = persons[
    persons['STUDE'].isin([1, 2])
    & persons['SCHOL'].isin([3, 4, 6, 7, 8])
    & persons['SNAME_lookup'].notna()
    & (persons['SNAME_lookup'] != "DK/RF")
]
len(students)

# Count sampled students per (school level, school name, school ZIP) triple;
# each distinct triple is treated as one "school".
schools_raw = (
    students
    .groupby(by=["SCHOL", "SNAME_lookup", "SZIP_lookup"])
    .size()
    .reset_index(name='enrollment')
)

# There are 3505 unique "schools", most of them have 1 student in the CHTS sample.
# For now, keep only schools with 3+ students in the CHTS sample, and number
# the survivors 0..n-1 in a 'school_id' column.
schools = (
    schools_raw[schools_raw['enrollment'] >= 3]
    .reset_index(drop=True)
    .rename_axis("school_id")
    .reset_index()
)

# Attach school_id to each student. The inner join drops students whose school
# was filtered out above, so no null school_id can remain.
# NOTE: merging on columns ignores the left frame's index, so the result
# carries a fresh integer index rather than the index `students` had before.
students = students.merge(
    schools.drop(columns=["enrollment"]),
    how="inner",
    on=["SCHOL", "SNAME_lookup", "SZIP_lookup"],
)
schools = schools[['school_id', 'enrollment']]

orca.add_table('students', students)
orca.add_table('schools', schools)
#len(students)
#len(schools)

In [None]:
## DIAGNOSTICS
#schools
#students[["index", "school_id"]]

In [None]:
# To be moved to scripts/variables.py

@orca.column('students', 'is_college_student', cache=True)
def is_college_student(students):
    """Flag each student as college level (1) or below (0) based on SCHOL.

    SCHOL codes 3 (Kindergarten to grade 8) and 4 (Grades 9 to 12) map to 0.
    Codes 6 (2-year college / community college), 7 (4-year college or
    university) and 8 (Graduate school / Professional) map to 1. Any other
    code has no entry in the mapping, so ``map`` leaves it as NaN.
    """
    below_college_levels = (3, 4)
    college_levels = (6, 7, 8)

    flag_by_level = dict.fromkeys(below_college_levels, 0)
    flag_by_level.update(dict.fromkeys(college_levels, 1))
    return students.SCHOL.map(flag_by_level)

#@orca.column('students', 'school_id', cache=True)
#def school_id(students, schools):
#    misc.reindex()
#    return pd.merge(students, schools, how="left", on=["SCHOL", "SNAME_lookup", "SZIP_lookup"])["school_id"]

@orca.column("students")
def zone_id_school(students, persons):
    """Return a placeholder ("fake") school zone id for every student.

    One zone id per student is drawn at random from the unique values of
    ``persons.zone_id_home``. The draw is unseeded and the column is not
    cached, so every evaluation produces different values.

    Returns
    -------
    pandas.Series
        Indexed like the students table, so orca can align it with the table
        instead of relying on positional order.
    """
    school_ids = students.school_id
    candidate_zones = np.unique(persons.zone_id_home)
    fake_zone_id_school = np.random.choice(candidate_zones, len(school_ids))
    return pd.Series(fake_zone_id_school, index=school_ids.index)


## Load skims for interaction terms

In [None]:
# Each skim file is reduced to the columns of interest, renamed so that
# 'orig'/'dest' become 'zone_id_home'/'zone_id_school', and indexed by that
# pair.

# Travel Time
skims = pd.read_csv(d + '/mtc_skims/TimeSkimsDatabaseAM.csv')
interaction_terms_tt = (
    skims[['orig', 'dest', 'da', 'wTrnW']]
    .rename(columns={
        'orig': 'zone_id_home',
        'dest': 'zone_id_school',
        'da': 'tt_da',
        'wTrnW': 'tt_wTrnW',
    })
    .set_index(['zone_id_home', 'zone_id_school'])
)
#interaction_terms_tt.to_csv('./data/WLCM_interaction_terms_tt.csv')

# Distance
skims = pd.read_csv(d + '/mtc_skims/DistanceSkimsDatabaseAM.csv')
interaction_terms_dist = (
    skims[['orig', 'dest', 'da', 'walk']]
    .rename(columns={
        'orig': 'zone_id_home',
        'dest': 'zone_id_school',
        'da': 'dist_da',
        'walk': 'dist_walk',
    })
    .set_index(['zone_id_home', 'zone_id_school'])
)
#interaction_terms_dist.to_csv('./data/WLCM_interaction_terms_dist.csv')

# Cost
skims = pd.read_csv(d + '/mtc_skims/CostSkimsDatabaseAM.csv')
interaction_terms_cost = (
    skims[['orig', 'dest', 'daToll', 'wTrnW']]
    .rename(columns={
        'orig': 'zone_id_home',
        'dest': 'zone_id_school',
        'daToll': 'cost_da_toll',
        'wTrnW': 'cost_wTrnW',
    })
    .set_index(['zone_id_home', 'zone_id_school'])
)
#interaction_terms_cost.to_csv('./data/WLCM_interaction_terms_cost.csv')


In [None]:
from choicemodels.tools import MergedChoiceTable

# Fetch the registered orca tables and materialise them as DataFrames for the
# choice table built in the next cell.
students = orca.get_table("students")
# The two bare attribute accesses below discard their results; presumably they
# are a check that both zone id columns resolve on the table wrapper before
# to_frame() is called -- TODO confirm intent.
students.zone_id_home
students.zone_id_school
students = students.to_frame()

schools = orca.get_table("schools").to_frame()

In [None]:
#%%time
#%memit
# Build the merged choice table: students are the observations, schools the
# alternatives, 'school_id' identifies the chosen alternative, and the three
# skim-derived tables are passed as interaction terms.
mct = MergedChoiceTable(
    students,
    schools,
    chosen_alternatives='school_id',
    sample_size=10,
    interaction_terms=[
        interaction_terms_tt,
        interaction_terms_dist,
        interaction_terms_cost,
    ],
)

# Configure models

In [None]:
from urbansim_templates import modelmanager
from urbansim_templates.models import SmallMultinomialLogitStep, LargeMultinomialLogitStep, SegmentedLargeMultinomialLogitStep

modelmanager.initialize()

# Base (unsegmented) model step; alt_sample_size matches the sample_size used
# when the merged choice table is built.
m0 = LargeMultinomialLogitStep(constrained_choices=True, alt_sample_size=10)

In [None]:
# Single-term specification: the 'tt_da' interaction term only.
m0.model_expression = 'tt_da'

# Estimate the base model on the pre-built merged choice table.
m0.fit(mct)

In [None]:
from urbansim_templates.models import SegmentedLargeMultinomialLogitStep

# Segmented model: uses m0 as the defaults and is segmented on the
# 'is_college_student' column.
m = SegmentedLargeMultinomialLogitStep(
    defaults=m0,
    name="school-choice-model",
    segmentation_column="is_college_student",
)

In [None]:
m.model_expression = 'tt_da'

# KNOWN ISSUE (original author's note): this is not working yet, as
# m.fit_all() does not accept an mct argument.
m.fit_all(mct)

In [None]:
# Rename the segmented step and register it with the model manager.
# The module is imported in this notebook as `modelmanager` (there is no `mm`
# alias), so it must be referenced by that name to avoid a NameError.
m.name = 'School-Choice-Model'
modelmanager.register(m)