In [1]:
import sys
sys.path.append("models/")

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from setup import *
from dataloader import SurveyDataset
import mnl

%load_ext autoreload
%autoreload 2

In [2]:
# Version tag of the filtered census-tract file (census_tracts_filtered-<tag>.csv).
data_version = "1571"

In [3]:
# Number of travel-mode alternatives; the `mode` column is coded 1..4.
n_alts = 4
trips_path = data_dir + "trips.csv"
tp = pd.read_csv(trips_path)

In [4]:
# Share of trips per mode code, relative to all rows of the trips table.
mode_share = tp['mode'].value_counts() / len(tp)
print(mode_share)

2    0.713060
1    0.132001
4    0.111893
3    0.043046
Name: mode, dtype: float64


In [5]:
# 0/1 peak-period indicators from the departure hour:
# morning   -> dep_hour strictly between 6 and 10
# afternoon -> dep_hour strictly between 15 and 19
dep_hour = tp['dep_hour']
tp['morning'] = ((dep_hour > 6) & (dep_hour < 10)).astype(int)
tp['afternoon'] = ((dep_hour > 15) & (dep_hour < 19)).astype(int)

def normalize_features(df, cols):
    """Rescale the given columns of ``df`` by their column maximum, in place.

    Each listed column is replaced by ``column / column.max()``. The same
    DataFrame object is returned; it is mutated, not copied.

    Args:
        df: DataFrame whose columns are rescaled.
        cols: Iterable of column labels, processed in order.

    Returns:
        The same ``df`` object, after modification.

    Note:
        No guard is applied. A column whose maximum is 0 produces inf/NaN
        values. The result is bounded by 1 only when the column is
        non-negative with a positive maximum.
    """
    for col_name in cols:
        column = df[col_name]
        df[col_name] = column / column.max()
    return df

In [6]:
# Constant feature column (value 1 on every row); presumably acts as the
# intercept term in the MNL utility — it is the first entry of the feature list.
tp.loc[:, 'const'] = 1

In [7]:
# Census tracts retained for this data version; trips are later restricted to
# origin/destination tracts listed here.
ct_path = data_dir + "census_tracts_filtered-" + data_version + ".csv"
ct_filter = pd.read_csv(ct_path)
unique_ct = ct_filter['geoid'].to_numpy()

In [8]:
unique_ct.shape[0]

1571

In [9]:
# Boolean mask over the rows of `tp`: True when BOTH the origin tract
# (tract_1) and destination tract (tract_2) appear exactly once in
# `unique_ct`. A tract listed twice is excluded, exactly as the per-trip
# `sum(unique_ct == t) == 1` test did.
# Vectorized: the previous Python loop compared each trip end against
# every tract (O(n_trips * n_tracts)); np.unique + np.isin is sort-based.
ct_values, ct_counts = np.unique(unique_ct, return_counts=True)
singleton_ct = ct_values[ct_counts == 1]
trip_filter = (
    np.isin(tp['tract_1'].to_numpy(), singleton_ct)
    & np.isin(tp['tract_2'].to_numpy(), singleton_ct)
)

In [10]:
# Explanatory variables fed to the MNL model, in column order (43 columns).
# NOTE(review): 'const' sits alongside what look like complete dummy sets
# (e.g. male/female, the five hh_inc_* bands). If any such set sums to 1 on
# every row, the coefficients are not separately identified — confirm
# against the survey coding.
feature_cols = [
    'const', 'morning', 'afternoon', 'companion', 'distance',
    'from_home', 'to_home', 'purp_work', 'purp_school', 'purp_errand', 'purp_recreation',
    'ontime_important', '12_18yrs', '18_25yrs', '25_55yrs', '55+yrs',
    'disability', 'educ_col', 'educ_grad',
    'race_white', 'race_black', 'race_asian',
    'male', 'female',
    'emply_park', 'emply_transit', 'emply_veh', 'emply_wfh', 'emply_flex', 'emply_hours',
    'license', 'person_trips', 'person_transit', 'person_freq_transit',
    'hh_inc_0_30', 'hh_inc_30_60', 'hh_inc_60_100', 'hh_inc_100_150', 'hh_inc_150',
    'avg_pr_veh', 'home_own', 'home_house', 'home_condo',
]

# Keep only trips whose tracts passed the census-tract filter.
x = tp[feature_cols].to_numpy()[trip_filter]

# Mode codes 1..4 shifted to 0-based class labels for CrossEntropyLoss.
mode_codes = tp['mode'].astype(int).to_numpy()
y = mode_codes[trip_filter] - 1

# 80/20 split with a fixed seed.
x_train, x_test, y_train, y_test = train_test_split(
    x, y,
    test_size=0.2,
    random_state=42,
)

In [11]:
np.shape(x)

(79929, 43)

In [12]:
# Wrap the numpy splits as tensors: float32 features and int64 (long) class
# labels, as nn.CrossEntropyLoss expects for class-index targets.
trainset = SurveyDataset(torch.tensor(x_train, dtype=torch.float), torch.tensor(y_train, dtype=torch.long))
trainloader = DataLoader(trainset, batch_size=256, shuffle=True)

# The whole test split is evaluated as a single batch. Shuffling is pointless
# for evaluation. With shuffle=True, the loader would also draw a seed from
# torch's global RNG on every pass, perturbing the training shuffle order,
# so it is turned off.
testset = SurveyDataset(torch.tensor(x_test, dtype=torch.float), torch.tensor(y_test, dtype=torch.long))
testloader = DataLoader(testset, batch_size=len(testset), shuffle=False)

In [13]:
# Fit mnl.MNL on the training split with Adam. After every epoch, report the
# mean train/test cross-entropy and accuracy.
#
# Stopping rules (checked only once epoch > warmup_epochs):
#   * converged: the summed test loss moved by less than `rel_tol` (relative)
#     versus BOTH of the previous two epochs -> sets `converged = 1`;
#   * "diverging": the previous epoch's test loss was lower than both the
#     current one and the one before it, i.e. the loss just turned upward.
# NOTE(review): both rules monitor the *test* split, so the printed test
# metrics are also used for model selection. A separate validation split
# would keep the test estimate unbiased.
loss_fn = nn.CrossEntropyLoss()

# Make parameter initialisation (presumably random) and the training
# shuffle order reproducible across kernel restarts.
torch.manual_seed(42)

model = mnl.MNL(n_alts=n_alts, n_features=x.shape[-1])
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=0)

max_epochs = 300
warmup_epochs = 15
rel_tol = 0.001


def run_epoch(net, loader, criterion, opt=None):
    """Make one pass over ``loader`` and return ``(summed_loss, n_correct)``.

    If ``opt`` is given, the model is put in training mode and updated after
    every batch. Otherwise it is put in eval mode and the pass runs with
    autograd disabled, so no computation graph is built for evaluation.

    ``summed_loss`` is each batch's mean loss weighted by the batch size.
    Dividing it by the dataset length gives the per-sample mean loss.
    ``n_correct`` counts rows whose highest-utility alternative equals the
    label.
    """
    training = opt is not None
    net.train(training)
    summed_loss = 0.0
    n_correct = 0
    with torch.set_grad_enabled(training):
        for x_batch, y_batch in loader:
            util = net(x_batch)
            loss = criterion(util, y_batch)
            summed_loss += loss.item() * len(y_batch)
            n_correct += (torch.argmax(util, dim=1) == y_batch).sum().item()
            if training:
                opt.zero_grad()
                loss.backward()
                opt.step()
    return summed_loss, n_correct


# Summed test loss of the previous epoch (ref1) and of the one before (ref2).
ref1 = 0
ref2 = 0
converged = 0

for epoch in range(max_epochs):
    loss_, correct = run_epoch(model, trainloader, loss_fn, optimizer)
    print(f"[epoch: {epoch:>3d}] Train loss: {loss_/len(trainset):.4f} accuracy: {correct/len(trainset):.3f}")

    loss_, correct = run_epoch(model, testloader, loss_fn)
    print(f"[epoch: {epoch:>3d}] Test loss: {loss_/len(testset):.4f} accuracy: {correct/len(testset):.3f}")

    if epoch > warmup_epochs:
        # ref1/ref2 hold real (non-zero) test losses by now, so the
        # relative-change divisions are safe.
        if abs(loss_ - ref1) / ref1 < rel_tol and abs(loss_ - ref2) / ref2 < rel_tol:
            print("Early stopping at epoch", epoch)
            converged = 1
            break
        if ref1 < loss_ and ref1 < ref2:
            print("Diverging. stop.")
            break

    ref2 = ref1
    ref1 = loss_
        

[epoch:   0] Train loss: 0.9531 accuracy: 0.660
[epoch:   0] Test loss: 0.7693 accuracy: 0.721
[epoch:   1] Train loss: 0.7201 accuracy: 0.726
[epoch:   1] Test loss: 0.6654 accuracy: 0.747
[epoch:   2] Train loss: 0.6393 accuracy: 0.756
[epoch:   2] Test loss: 0.6069 accuracy: 0.774
[epoch:   3] Train loss: 0.5927 accuracy: 0.780
[epoch:   3] Test loss: 0.5717 accuracy: 0.794
[epoch:   4] Train loss: 0.5629 accuracy: 0.796
[epoch:   4] Test loss: 0.5474 accuracy: 0.804
[epoch:   5] Train loss: 0.5419 accuracy: 0.805
[epoch:   5] Test loss: 0.5302 accuracy: 0.811
[epoch:   6] Train loss: 0.5262 accuracy: 0.811
[epoch:   6] Test loss: 0.5168 accuracy: 0.816
[epoch:   7] Train loss: 0.5139 accuracy: 0.816
[epoch:   7] Test loss: 0.5061 accuracy: 0.820
[epoch:   8] Train loss: 0.5043 accuracy: 0.820
[epoch:   8] Test loss: 0.4983 accuracy: 0.824
[epoch:   9] Train loss: 0.4962 accuracy: 0.823
[epoch:   9] Test loss: 0.4907 accuracy: 0.826
[epoch:  10] Train loss: 0.4896 accuracy: 0.826
[e

In [14]:
# Dump every learnable tensor of the fitted model as a (name, tensor) pair.
for param_name, param in model.named_parameters():
    print((param_name, param))

('beta.weight', Parameter containing:
tensor([[ 1.7323e-01, -1.3745e-01, -5.9457e-03, -3.2567e-02, -5.8126e-01,
         -5.0740e-02,  5.7224e-02,  3.3084e-01,  3.7294e-01, -4.8436e-02,
          3.1775e-01, -1.5346e-01,  2.3250e-02,  3.8133e-01,  7.3825e-02,
          1.4426e-01, -3.1204e-01,  3.3273e-01,  4.1470e-01,  2.4266e-01,
         -6.6224e-02,  2.5318e-02,  2.3201e-01,  1.7770e-01, -3.1163e-01,
          3.7078e-01, -2.0950e-01,  8.2269e-03,  3.3414e-02,  8.2250e-02,
         -1.5508e-01,  1.2813e-03, -1.8687e-02, -1.9117e-01,  1.9198e-01,
          8.4143e-02,  6.8167e-02,  5.5181e-02,  6.9886e-02, -6.6158e-01,
         -7.7369e-02, -7.7034e-02,  4.9229e-01],
        [-2.3005e-01, -1.1370e-01,  7.1415e-03,  2.5589e-01,  3.7204e-01,
          1.5575e-01,  5.6162e-02, -3.0736e-01, -5.1040e-01,  2.7782e-01,
         -1.7180e-02,  4.4394e-03, -3.5344e-01, -3.2084e-01,  7.0169e-03,
         -7.5351e-02, -1.6085e-02, -3.1900e-01, -3.5089e-01, -1.2440e-02,
         -1.3163e-01,  1.