In [47]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt # this is used for the plot the graph 

from sklearn.decomposition import PCA
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import OneHotEncoder

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import TensorDataset, DataLoader

%matplotlib inline

import random 
random.seed(42)

In [48]:
csv_path = "training_dataset.csv"

# 0. First, let's understand our data...

In [49]:
df = pd.read_csv(csv_path) # load the pandas dataframe
initial_cols_to_drop = ["Unnamed: 0","Unnamed: 0.1", "period", "test"] 
for col_name in initial_cols_to_drop: # drops columns that aren't supposed to be in dataset
    try:
        df = df.drop(columns=[col_name])
    except:
        continue
#df = df.rename(columns={"Unnamed: 0.1": "TODO_FIND_COLUMN_NAME_2"})
display(df.head())

Unnamed: 0,dt,weekday,year,id_driver,id_carrier_number,dim_carrier_type,dim_carrier_company_name,home_base_city,home_base_state,carrier_trucks,...,marketplace_loads_otr,marketplace_loads_atlas,marketplace_loads,brokerage_loads_otr,brokerage_loads_atlas,brokerage_loads,total_loads,date,recent_date,label
0,2018-08-23,4,2018,13577,C0092604,Fleet,US HONG CORP,City of Industry,CA,"[""dryvan""]",...,0,0,0,2,0,2,2,2018-08-23,2018-08-27,0
1,2018-05-11,5,2018,7066,C0090412,Fleet,Carlos Flores,Los Angeles,CA,"[""dryvan""]",...,134,0,134,186,0,186,320,2018-05-11,2019-09-05,0
2,2019-03-23,6,2019,16776,C0093729,Fleet,AZIEL INC,Ontario,CA,"[""boxtruck""]",...,11,0,11,97,0,97,108,2019-03-23,2019-10-17,0
3,2019-07-03,3,2019,11085,U0099848,Owner Operator,Expert Carriers,Long Beach,CA,"[""poweronly""]",...,15,0,15,179,0,179,194,2019-07-03,2019-12-28,0
4,2020-11-07,6,2020,14644,U0102863,Owner Operator,Edgar & daughters,South El Monte,CA,"[""poweronly""]",...,17,279,296,0,8,8,304,2020-11-07,2021-02-17,0


In [50]:
df.describe()

Unnamed: 0,weekday,year,id_driver,num_trucks,days_signup_to_approval,loads,marketplace_loads_otr,marketplace_loads_atlas,marketplace_loads,brokerage_loads_otr,brokerage_loads_atlas,brokerage_loads,total_loads,label
count,83414.0,83414.0,83414.0,83344.0,71128.0,83414.0,83414.0,83414.0,83414.0,83414.0,83414.0,83414.0,83414.0,83414.0
mean,3.499149,2018.960906,18221.333098,22.588153,298.818229,2.07398,29.545844,71.541156,101.087,148.117606,13.046407,161.164013,266.368499,0.112751
std,1.688132,1.359378,11666.420279,48.840757,390.379696,2.661652,88.415228,194.473946,214.547898,415.015279,42.172475,412.831763,447.96442,0.31629
min,1.0,2015.0,20.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
25%,2.0,2018.0,7898.5,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,5.0,37.0,0.0
50%,3.0,2019.0,16299.0,4.0,62.0,1.0,2.0,0.0,13.0,15.0,0.0,37.0,110.0,0.0
75%,5.0,2020.0,28974.0,14.0,497.0,2.0,23.0,18.0,94.0,112.0,1.0,135.0,325.0,0.0
max,7.0,2021.0,38125.0,195.0,1653.0,129.0,902.0,1324.0,1348.0,4266.0,371.0,4266.0,4266.0,1.0


In [51]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 83414 entries, 0 to 83413
Data columns (total 33 columns):
 #   Column                    Non-Null Count  Dtype  
---  ------                    --------------  -----  
 0   dt                        83414 non-null  object 
 1   weekday                   83414 non-null  int64  
 2   year                      83414 non-null  int64  
 3   id_driver                 83414 non-null  int64  
 4   id_carrier_number         83414 non-null  object 
 5   dim_carrier_type          83414 non-null  object 
 6   dim_carrier_company_name  83367 non-null  object 
 7   home_base_city            83370 non-null  object 
 8   home_base_state           83370 non-null  object 
 9   carrier_trucks            83414 non-null  object 
 10  num_trucks                83344 non-null  float64
 11  interested_in_drayage     83414 non-null  object 
 12  port_qualified            83414 non-null  object 
 13  signup_source             83414 non-null  object 
 14  ts_sig

# 1. Generate Labels

In [52]:
# converts date from csv to a python datetime object making it easier to work with
df['most_recent_load_date'] = pd.to_datetime(df['most_recent_load_date'], format='%Y-%m-%d')

In [53]:
df['most_recent_load_date'].head()

0   2018-08-27
1   2019-09-05
2   2019-10-17
3   2019-12-28
4   2021-02-17
Name: most_recent_load_date, dtype: datetime64[ns]

In [54]:
loads75 = df.loads.quantile(0.75) # finds 75th percentile of loads
most_recent_load_date75 = df.most_recent_load_date.quantile(0.75) # finds 75th percentile of most recent load date

print(loads75)
print(most_recent_load_date75)
# Manual Check
# sorted_dts = sorted(list(df.most_recent_load_date))
# quartile_estimate_index = int(len(sorted_dts)*0.75)
# print("SORTED INDEX", sorted_dts[quartile_estimate_index])

2.0
2021-02-14 00:00:00


In [55]:
new_labels = {"label": {}}
for index, row in df.iterrows(): # changes the labels in the label columns
    # checks if the load and most recent load date are in the 75th percentile
    if row["loads"] >= loads75 and row["most_recent_load_date"] >= most_recent_load_date75:
        df.at[index, "label"] = 1
    else:
        df.at[index, "label"] = 0

In [56]:
df["label"].describe()

count    83414.000000
mean         0.151389
std          0.358430
min          0.000000
25%          0.000000
50%          0.000000
75%          0.000000
max          1.000000
Name: label, dtype: float64

# 2. Drop some columns

In [57]:
df = df.drop(columns=["loads", "most_recent_load_date"])

In [58]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 83414 entries, 0 to 83413
Data columns (total 31 columns):
 #   Column                    Non-Null Count  Dtype  
---  ------                    --------------  -----  
 0   dt                        83414 non-null  object 
 1   weekday                   83414 non-null  int64  
 2   year                      83414 non-null  int64  
 3   id_driver                 83414 non-null  int64  
 4   id_carrier_number         83414 non-null  object 
 5   dim_carrier_type          83414 non-null  object 
 6   dim_carrier_company_name  83367 non-null  object 
 7   home_base_city            83370 non-null  object 
 8   home_base_state           83370 non-null  object 
 9   carrier_trucks            83414 non-null  object 
 10  num_trucks                83344 non-null  float64
 11  interested_in_drayage     83414 non-null  object 
 12  port_qualified            83414 non-null  object 
 13  signup_source             83414 non-null  object 
 14  ts_sig

# 3. Basic Statistics

In [59]:
corr_matrix = df.corr()
for col_name in (list(df.columns.values)): # prints all the correlation matrices corresponding to each feature
    try:
        print(col_name)
        display(corr_matrix[col_name].sort_values(ascending=False))
        print('---------------------------------------------------------------------')
    except:
        print("{} is not of type integer".format(col_name))
        print('---------------------------------------------------------------------')

dt
dt is not of type integer
---------------------------------------------------------------------
weekday


weekday                    1.000000
year                       0.055610
id_driver                  0.053393
brokerage_loads_atlas      0.008955
label                      0.007623
marketplace_loads_otr      0.004377
num_trucks                 0.001470
marketplace_loads          0.001357
marketplace_loads_atlas   -0.000493
total_loads               -0.017164
brokerage_loads           -0.020757
brokerage_loads_otr       -0.021558
days_signup_to_approval   -0.053005
Name: weekday, dtype: float64

---------------------------------------------------------------------
year


year                       1.000000
id_driver                  0.784479
label                      0.394704
marketplace_loads_atlas    0.290073
brokerage_loads_atlas      0.274956
marketplace_loads          0.273890
total_loads                0.117035
num_trucks                 0.108146
weekday                    0.055610
marketplace_loads_otr      0.026589
brokerage_loads           -0.027475
brokerage_loads_otr       -0.055271
days_signup_to_approval   -0.644627
Name: year, dtype: float64

---------------------------------------------------------------------
id_driver


id_driver                  1.000000
year                       0.784479
label                      0.325040
brokerage_loads_atlas      0.280231
marketplace_loads_atlas    0.253327
marketplace_loads          0.191855
num_trucks                 0.143885
weekday                    0.053393
total_loads               -0.073022
marketplace_loads_otr     -0.091652
brokerage_loads           -0.192208
brokerage_loads_otr       -0.219673
days_signup_to_approval   -0.796913
Name: id_driver, dtype: float64

---------------------------------------------------------------------
id_carrier_number
id_carrier_number is not of type integer
---------------------------------------------------------------------
dim_carrier_type
dim_carrier_type is not of type integer
---------------------------------------------------------------------
dim_carrier_company_name
dim_carrier_company_name is not of type integer
---------------------------------------------------------------------
home_base_city
home_base_city is not of type integer
---------------------------------------------------------------------
home_base_state
home_base_state is not of type integer
---------------------------------------------------------------------
carrier_trucks
carrier_trucks is not of type integer
---------------------------------------------------------------------
num_trucks


num_trucks                 1.000000
id_driver                  0.143885
label                      0.115391
year                       0.108146
days_signup_to_approval    0.051251
weekday                    0.001470
brokerage_loads_otr       -0.050238
brokerage_loads           -0.062842
total_loads               -0.114414
marketplace_loads_otr     -0.120234
brokerage_loads_atlas     -0.120770
marketplace_loads_atlas   -0.150238
marketplace_loads         -0.185737
Name: num_trucks, dtype: float64

---------------------------------------------------------------------
interested_in_drayage
interested_in_drayage is not of type integer
---------------------------------------------------------------------
port_qualified
port_qualified is not of type integer
---------------------------------------------------------------------
signup_source
signup_source is not of type integer
---------------------------------------------------------------------
ts_signup
ts_signup is not of type integer
---------------------------------------------------------------------
ts_first_approved
ts_first_approved is not of type integer
---------------------------------------------------------------------
days_signup_to_approval


days_signup_to_approval    1.000000
brokerage_loads_otr        0.133629
brokerage_loads            0.109760
marketplace_loads_otr      0.085119
num_trucks                 0.051251
total_loads               -0.004422
weekday                   -0.053005
marketplace_loads         -0.159411
brokerage_loads_atlas     -0.188662
marketplace_loads_atlas   -0.213705
label                     -0.253444
year                      -0.644627
id_driver                 -0.796913
Name: days_signup_to_approval, dtype: float64

---------------------------------------------------------------------
driver_with_twic
driver_with_twic is not of type integer
---------------------------------------------------------------------
dim_preferred_lanes
dim_preferred_lanes is not of type integer
---------------------------------------------------------------------
first_load_date
first_load_date is not of type integer
---------------------------------------------------------------------
load_day
load_day is not of type integer
---------------------------------------------------------------------
marketplace_loads_otr


marketplace_loads_otr      1.000000
marketplace_loads          0.422470
total_loads                0.180977
days_signup_to_approval    0.085119
year                       0.026589
marketplace_loads_atlas    0.011441
weekday                    0.004377
brokerage_loads_otr       -0.013174
brokerage_loads           -0.019866
label                     -0.030935
brokerage_loads_atlas     -0.064830
id_driver                 -0.091652
num_trucks                -0.120234
Name: marketplace_loads_otr, dtype: float64

---------------------------------------------------------------------
marketplace_loads_atlas


marketplace_loads_atlas    1.000000
marketplace_loads          0.911151
brokerage_loads_atlas      0.409536
label                      0.404792
total_loads                0.360903
year                       0.290073
id_driver                  0.253327
marketplace_loads_otr      0.011441
weekday                   -0.000493
brokerage_loads           -0.078238
brokerage_loads_otr       -0.119442
num_trucks                -0.150238
days_signup_to_approval   -0.213705
Name: marketplace_loads_atlas, dtype: float64

---------------------------------------------------------------------
marketplace_loads


marketplace_loads          1.000000
marketplace_loads_atlas    0.911151
marketplace_loads_otr      0.422470
total_loads                0.401716
label                      0.354169
brokerage_loads_atlas      0.344502
year                       0.273890
id_driver                  0.191855
weekday                    0.001357
brokerage_loads           -0.079104
brokerage_loads_otr       -0.113695
days_signup_to_approval   -0.159411
num_trucks                -0.185737
Name: marketplace_loads, dtype: float64

---------------------------------------------------------------------
brokerage_loads_otr


brokerage_loads_otr        1.000000
brokerage_loads            0.994824
total_loads                0.859125
days_signup_to_approval    0.133629
label                      0.105394
marketplace_loads_otr     -0.013174
weekday                   -0.021558
num_trucks                -0.050238
year                      -0.055271
brokerage_loads_atlas     -0.102448
marketplace_loads         -0.113695
marketplace_loads_atlas   -0.119442
id_driver                 -0.219673
Name: brokerage_loads_otr, dtype: float64

---------------------------------------------------------------------
brokerage_loads_atlas


brokerage_loads_atlas      1.000000
marketplace_loads_atlas    0.409536
marketplace_loads          0.344502
label                      0.320236
id_driver                  0.280231
year                       0.274956
total_loads                0.161382
weekday                    0.008955
brokerage_loads           -0.000836
marketplace_loads_otr     -0.064830
brokerage_loads_otr       -0.102448
num_trucks                -0.120770
days_signup_to_approval   -0.188662
Name: brokerage_loads_atlas, dtype: float64

---------------------------------------------------------------------
brokerage_loads


brokerage_loads            1.000000
brokerage_loads_otr        0.994824
total_loads                0.880155
label                      0.138665
days_signup_to_approval    0.109760
brokerage_loads_atlas     -0.000836
marketplace_loads_otr     -0.019866
weekday                   -0.020757
year                      -0.027475
num_trucks                -0.062842
marketplace_loads_atlas   -0.078238
marketplace_loads         -0.079104
id_driver                 -0.192208
Name: brokerage_loads, dtype: float64

---------------------------------------------------------------------
total_loads


total_loads                1.000000
brokerage_loads            0.880155
brokerage_loads_otr        0.859125
marketplace_loads          0.401716
marketplace_loads_atlas    0.360903
label                      0.306941
marketplace_loads_otr      0.180977
brokerage_loads_atlas      0.161382
year                       0.117035
days_signup_to_approval   -0.004422
weekday                   -0.017164
id_driver                 -0.073022
num_trucks                -0.114414
Name: total_loads, dtype: float64

---------------------------------------------------------------------
date
date is not of type integer
---------------------------------------------------------------------
recent_date
recent_date is not of type integer
---------------------------------------------------------------------
label


label                      1.000000
marketplace_loads_atlas    0.404792
year                       0.394704
marketplace_loads          0.354169
id_driver                  0.325040
brokerage_loads_atlas      0.320236
total_loads                0.306941
brokerage_loads            0.138665
num_trucks                 0.115391
brokerage_loads_otr        0.105394
weekday                    0.007623
marketplace_loads_otr     -0.030935
days_signup_to_approval   -0.253444
Name: label, dtype: float64

---------------------------------------------------------------------


 Also year and TODO_FIND_COLUMN_NAME_2 and year are highly correlated and have a similar impact on label, so we could drop one? 

Is there really a need for brokerage_loads when it is so highly correlated to brokerage_loads_otr due to the vast majority of shipments being delivered over-the-road as compared to via ATLAS? 

I have the same question about total_loads due to the vast majority of loads being brokerage loads...

What's the point of having both year and date?

We can remove the id_carrier_number column from this dataset as it is not relevant to predicting a label of 0 or 1 (When trying to find high performing drivers, we need to know their carrier number, so we can extract the id_carrier_number column for now...)

We could one-hot-encode sign-up source and see its effect on labels.

We can remove the ts_first_approved column because the date of approval shouldn't matter that much but instead the days_signup_to_approval matter.

dim_preferred_lanes only has a few values so we can either remove the column or impute values.

Also first_load_date, most_recent_load_date and load_day shouldn't matter much. Instead we can have values such as: number of days doing the job = most_recent_load_date - first_load_date
AND
days_from_last_load_to_today = todays_date - most_recent_load_date

There are also a couple other features we need to impute.

Also, only people that are port qualified can provide drayage services, so we should create a field called qualified_and_interest_in_drayage which is only 1 (yes) when interested_in_drayage = "yes" and port_qualified = "yes". We can also cross these features...

# 4. Data Feature Extraction Plan and Pipeline

In [60]:
df["location"] = list(zip(df["home_base_city"], df["home_base_state"]))# feature cross to get (city, state) tuple
# feature cross for interested in drayage and port qualified
df["drayage_interested_port_qualified"] = list(zip(df["interested_in_drayage"], df["port_qualified"]))
display(df["location"])
display(df["drayage_interested_port_qualified"])

0        (City of Industry, CA)
1             (Los Angeles, CA)
2                 (Ontario, CA)
3              (Long Beach, CA)
4          (South El Monte, CA)
                  ...          
83409             (Seattle, WA)
83410         (Los Angeles, CA)
83411           (La Puente, CA)
83412             (Phoenix, AZ)
83413             (Compton, CA)
Name: location, Length: 83414, dtype: object

0                  (yes, no)
1        (not specified, no)
2                  (yes, no)
3                 (yes, yes)
4        (not specified, no)
                ...         
83409    (not specified, no)
83410    (not specified, no)
83411    (not specified, no)
83412    (not specified, no)
83413             (yes, yes)
Name: drayage_interested_port_qualified, Length: 83414, dtype: object

In [61]:
id_carrier_number_col = np.array(df["id_carrier_number"]) # extract id_carrier_number column
id_driver_number_col = np.array(df["id_driver"]) # extract id_driver column

drop_cols = ["id_carrier_number", "id_driver", "weekday", "home_base_city",
             "home_base_state", "interested_in_drayage", "port_qualified"]

df = df.drop(columns = drop_cols) # drop columns that don't affect the label value by much
print(df.dim_carrier_company_name.nunique())

2493


In [62]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 83414 entries, 0 to 83413
Data columns (total 26 columns):
 #   Column                             Non-Null Count  Dtype  
---  ------                             --------------  -----  
 0   dt                                 83414 non-null  object 
 1   year                               83414 non-null  int64  
 2   dim_carrier_type                   83414 non-null  object 
 3   dim_carrier_company_name           83367 non-null  object 
 4   carrier_trucks                     83414 non-null  object 
 5   num_trucks                         83344 non-null  float64
 6   signup_source                      83414 non-null  object 
 7   ts_signup                          83414 non-null  object 
 8   ts_first_approved                  71128 non-null  object 
 9   days_signup_to_approval            71128 non-null  float64
 10  driver_with_twic                   83414 non-null  object 
 11  dim_preferred_lanes                3413 non-null   obj

In [None]:
imputer = IterativeImputer()
categorical_features_one_hot = ["dim_carrier_type", "carrier_trucks", "location", 
                                "signup_source", "driver_with_twic", "drayage_interested_port_qualified"]

num_pipeline = Pipeline([
        ('imputer', imputer),
        ('std_scaler', StandardScaler()),
    ])

full_pipeline = ColumnTransformer([
        ("num", num_pipeline, numerical_features),
        ("cat", OneHotEncoder(), categorical_features_one_hot),
    ])

# 6. PCA

In [None]:
max_len = len(df.columns)
arr = []
for i in range(1, max_len+1):
    pca = PCA(n_components=len(df.columns))
    pca.fit(df)
    arr.append(pca)
    #pca.fit_transform(df)
# need to transform test data after finishing data pipelining

# 7. Ensemble AKA Robert's BS pls mercy

In [None]:
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score

# assuming we have X_train,X_test,y_train,y_test at this time
# I first run Random Forest using random hard coded settings to get a baseline
rf = RandomForestRegressor(n_estimators=80,max_depth=7,max_features=3)
rf.fit(X_train,y_train)
y_pred = rf.predict(X_test)
test_score = r2_score(y_test,y_pred)
test_score

In [None]:
from sklearn.model_selection import RandomizedSearchCV
import numpy as np
# I then use RandomizedSearchCV to find the optimal hyperparameters
n_estimators = [int(x) for x in np.linspace(start = 100, stop = 1000, num = 10)]
max_features = ['auto', 'sqrt']
max_depth = [int(x) for x in np.linspace(5, 100, num = 10)]
max_depth.append(None)
min_samples_split = [2, 5, 7, 10, 20]
min_samples_leaf = [1, 2, 5, 10]
bootstrap = [True, False]
random_grid = {'n_estimators': n_estimators,
               'max_features': max_features,
               'max_depth': max_depth,
               'min_samples_split': min_samples_split,
               'min_samples_leaf': min_samples_leaf,
               'bootstrap': bootstrap}

rf_random = RandomizedSearchCV(estimator = rf, param_distributions = random_grid, n_iter = 100, cv = 3, verbose=2, random_state=42, n_jobs = -1)
rf_random.fit(X_train, y_train)
rf_random.best_params_

In [None]:
# I then output the r2 score again as a sanity check to verify that my RanomdizedSearchCV actually did find the best settings
rf = RandomForestRegressor(n_estimators=700,max_depth=47,max_features='auto',min_samples_split=2,min_samples_leaf=2,bootstrap=True)
rf.fit(X_train,y_train)
y_pred = rf.predict(X_test)
from sklearn.metrics import r2_score
test_score = r2_score(y_test,y_pred)
test_score

In [None]:
# we then use the hyperparameters we found from the RandomizedSearchCV to do a second more thorough check around that range
from sklearn.model_selection import GridSearchCV
param_grid = {
    'bootstrap': [True],
    'max_depth': [40, 45, 50, 55, 60],
    'max_features': [2, 5, 7, 10, 12],
    'min_samples_leaf': [2, 3, 4, 5, 6],
    'min_samples_split': [2, 3, 4, 5, 6],
    'n_estimators': [100, 200, 500, 700, 1000]
}
# Create a based model
rf = RandomForestRegressor()
# Instantiate the grid search model
grid_search = GridSearchCV(estimator = rf, param_grid = param_grid, 
                          cv = 3, n_jobs = -1, verbose = 2)
grid_search.fit(X_train, y_train)
grid_search.best_params_

In [None]:
# we then output the results using the optimal hyperparameters to check that our model has improved
rf = RandomForestRegressor(n_estimators=500,max_depth=40,max_features=7,min_samples_split=4,min_samples_leaf=2,bootstrap=True)
rf.fit(X_train,y_train)
y_pred = rf.predict(X_test)
from sklearn.metrics import r2_score
test_score = r2_score(y_test,y_pred)
test_score

In [None]:
# AdaBoost using the same settings
from sklearn.ensemble import AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn import metrics

dt = DecisionTreeClassifier(max_depth=3)
ab = AdaBoostClassifier(base_estimator=dt,learning_rate=1,n_estimators=50)
ab.fit(X_train,y_train)
y_pred = ab.predict(X_test)
test_score = r2_score(y_test,y_pred)
accuracyResult = metrics.accuracy_score(y_test,y_pred)
print("R2 Score: ",test_score)
print("Accuracy Score: ",accuracyResult)

# 8. Neural Network Classifier

In [None]:
# # Define a transform to normalize the data
# transform = transforms.Compose([transforms.ToTensor(),
#                                 transforms.Normalize((0.5,), (0.5,)),
#                               ])
# Download and load the training data
NN_X_train = torch.tensor(df.drop(["label"], axis=1).values)
NN_y_train = torch.tensor(df["label"].values)
trainset = TensorDataset(NN_X_train, NN_y_train)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

In [None]:
NUM_FEATURES = 20
NUM_HIDDEN1_NODES = 400
NUM_EPOCHS = 30

model = nn.Sequential(nn.Linear(NUM_FEATURES, NUM_HIDDEN1_NODES),
                      nn.Sigmoid(),
                      nn.Linear(NUM_HIDDEN1_NODES, 1)
                     )

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.003)

for e in range(NUM_EPOCHS):
    running_loss = 0
    for data, labels in trainloader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Training loss: {running_loss/len(trainloader)}")
    
# class NeuralNet(nn.Module):
#     def __init__(self):
#         super(NeuralNet, self).__init__()
        
#         # Inputs to hidden layer linear transformation
#         self.hidden1 = nn.Linear(NUM_FEATURES, NUM_HIDDEN1_NODES)
#         self.output = nn.Linear(NUM_HIDDEN1_NODES, 1)
        
#         self.sigmoid = nn.Sigmoid()
        
#     def forward(self, x):
#         x = self.hidden1(x)
#         x = self.sigmoid(x)
#         x = self.output(x)
#         x = self.softmax(x)
#         return x

In [None]:
# model = Network()
# model