The idea of this notebook is to create a first (simple) deep learning (DL) model using all of the same features that are used in the multinomial logit model.

Much inspiration was derived from https://towardsdatascience.com/use-machine-learning-to-predict-horse-racing-4f1111fb6ced.

In [21]:
import math
from importlib import reload
import deeplearninglib
reload(deeplearninglib)
from deeplearninglib import *

import wandb

# Training is pinned to the CPU. To use a GPU when one is present, swap the
# assignment below for:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
print("Using {} device".format(device))

Using cpu device


In [22]:
# select model to train

# Column-name building blocks shared by several inventory entries.
# Expands to position1_1 ... position1_4, position2_1 ... position3_4 (in that order).
_position_features = [f"position{first}_{second}" for first in (1, 2, 3) for second in (1, 2, 3, 4)]

_market_features = ["mkt_prob"]
_alun_owen_v0_features = (
    ["age", "sire_sr", "dam_sr", "trainer_sr", "daysLTO"]
    + _position_features
    + ["entire", "gelding", "blinkers", "visor", "cheekpieces", "tonguetie"]
)
# Same as the v0 list minus sire_sr, dam_sr and visor; used by v1, v2 and v3.
_alun_owen_v1_features = (
    ["age", "trainer_sr", "daysLTO"]
    + _position_features
    + ["entire", "gelding", "blinkers", "cheekpieces", "tonguetie"]
)


def _entry(columns, rate, n_epochs, architecture, use_bias=True):
    """Build one inventory record; every entry gets its own copy of the column list."""
    return {'X_columns': list(columns),
            'learning_rate': rate,
            'epochs': n_epochs,
            'vacant_stall_indicator': False,
            'bias': use_bias,
            'model_architecture': architecture}


# NOTE(review): 10e-1 evaluates to 1.0 and 10e-3 to 0.01 (not 0.1 / 0.001) -
# confirm these are the intended learning rates.
model_inventory = {
    'mktprob': _entry(_market_features, 10e-1, 50, LinSig),
    'mktprob_soft': _entry(_market_features, 10e-1, 50, LinSoft),
    'mktprob_MLR': _entry(_market_features, 10e-1, 50, MLR),
    'AlunOwen_v0': _entry(_alun_owen_v0_features, 10e-3, 100, LinSig),
    'AlunOwen_v1': _entry(_alun_owen_v1_features, 10e-3, 100, LinSig),
    'AlunOwen_v2': _entry(_alun_owen_v1_features, 10e-3, 100, LinDropReluLinSoft),
    'AlunOwen_v3': _entry(_alun_owen_v1_features, 10e-3, 100, MLR, use_bias=False),
}

# `model` is the inventory key of the configuration to train; the settings of
# that entry are unpacked into the names the following cells read.
model = 'AlunOwen_v3'
_selected = model_inventory[model]
(X_columns, learning_rate, epochs,
 vacant_stall_indicator, bias, model_architecture) = (
    _selected[key]
    for key in ('X_columns', 'learning_rate', 'epochs',
                'vacant_stall_indicator', 'bias', 'model_architecture')
)

In [23]:
# read in data

# Target column(s) handed to RacesDataset.
y_columns = ["win"]

# Forward slashes are accepted by Python's file APIs on Windows as well as on
# Linux/macOS; the previous "data\\..." literals only resolved on Windows.
train_data_fn = "data/runners_train.csv"
test_data_fn = "data/runners_test.csv"

# Number of dataset items per mini-batch for both loaders.
BATCH_SIZE = 64

train_data = RacesDataset(train_data_fn, X_columns, y_columns, vacant_stall_indicator=vacant_stall_indicator)
# The test set is built with the `scalar` object taken from the training set
# (presumably a feature scaler fitted on the training data - confirm in deeplearninglib).
test_data = RacesDataset(test_data_fn, X_columns, y_columns, vacant_stall_indicator=vacant_stall_indicator, scalar=train_data.scalar)

# NOTE(review): the training loader iterates in file order (no shuffle=True);
# confirm that is intended for SGD training.
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE)

In [24]:
# Raise pandas' column display limit so the wide frame below is not truncated.
# This is a global option and stays in effect for later cells.
pd.set_option("display.max_columns", 1000)  # was 20

# First rows of the columns that feed the model; train_data.X_columns is used
# positionally here (iloc).
(
    train_data.races
    .iloc[:, train_data.X_columns]
    .head()
)

Unnamed: 0_level_0,age,age,age,age,age,age,age,age,age,age,age,age,age,age,age,age,blinkers,blinkers,blinkers,blinkers,blinkers,blinkers,blinkers,blinkers,blinkers,blinkers,blinkers,blinkers,blinkers,blinkers,blinkers,blinkers,cheekpieces,cheekpieces,cheekpieces,cheekpieces,cheekpieces,cheekpieces,cheekpieces,cheekpieces,cheekpieces,cheekpieces,cheekpieces,cheekpieces,cheekpieces,cheekpieces,cheekpieces,cheekpieces,daysLTO,daysLTO,daysLTO,daysLTO,daysLTO,daysLTO,daysLTO,daysLTO,daysLTO,daysLTO,daysLTO,daysLTO,daysLTO,daysLTO,daysLTO,daysLTO,entire,entire,entire,entire,entire,entire,entire,entire,entire,entire,entire,entire,entire,entire,entire,entire,gelding,gelding,gelding,gelding,gelding,gelding,gelding,gelding,gelding,gelding,gelding,gelding,gelding,gelding,gelding,gelding,position1_1,position1_1,position1_1,position1_1,position1_1,position1_1,position1_1,position1_1,position1_1,position1_1,position1_1,position1_1,position1_1,position1_1,position1_1,position1_1,position1_2,position1_2,position1_2,position1_2,position1_2,position1_2,position1_2,position1_2,position1_2,position1_2,position1_2,position1_2,position1_2,position1_2,position1_2,position1_2,position1_3,position1_3,position1_3,position1_3,position1_3,position1_3,position1_3,position1_3,position1_3,position1_3,position1_3,position1_3,position1_3,position1_3,position1_3,position1_3,position1_4,position1_4,position1_4,position1_4,position1_4,position1_4,position1_4,position1_4,position1_4,position1_4,position1_4,position1_4,position1_4,position1_4,position1_4,position1_4,position2_1,position2_1,position2_1,position2_1,position2_1,position2_1,position2_1,position2_1,position2_1,position2_1,position2_1,position2_1,position2_1,position2_1,position2_1,position2_1,position2_2,position2_2,position2_2,position2_2,position2_2,position2_2,position2_2,position2_2,position2_2,position2_2,position2_2,position2_2,position2_2,position2_2,position2_2,position2_2,position2_3,position2_3,position2_3,position2_3,position2_3,p
osition2_3,position2_3,position2_3,position2_3,position2_3,position2_3,position2_3,position2_3,position2_3,position2_3,position2_3,position2_4,position2_4,position2_4,position2_4,position2_4,position2_4,position2_4,position2_4,position2_4,position2_4,position2_4,position2_4,position2_4,position2_4,position2_4,position2_4,position3_1,position3_1,position3_1,position3_1,position3_1,position3_1,position3_1,position3_1,position3_1,position3_1,position3_1,position3_1,position3_1,position3_1,position3_1,position3_1,position3_2,position3_2,position3_2,position3_2,position3_2,position3_2,position3_2,position3_2,position3_2,position3_2,position3_2,position3_2,position3_2,position3_2,position3_2,position3_2,position3_3,position3_3,position3_3,position3_3,position3_3,position3_3,position3_3,position3_3,position3_3,position3_3,position3_3,position3_3,position3_3,position3_3,position3_3,position3_3,position3_4,position3_4,position3_4,position3_4,position3_4,position3_4,position3_4,position3_4,position3_4,position3_4,position3_4,position3_4,position3_4,position3_4,position3_4,position3_4,tonguetie,tonguetie,tonguetie,tonguetie,tonguetie,tonguetie,tonguetie,tonguetie,tonguetie,tonguetie,tonguetie,tonguetie,tonguetie,tonguetie,tonguetie,tonguetie,trainer_sr,trainer_sr,trainer_sr,trainer_sr,trainer_sr,trainer_sr,trainer_sr,trainer_sr,trainer_sr,trainer_sr,trainer_sr,trainer_sr,trainer_sr,trainer_sr,trainer_sr,trainer_sr
stall_number,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16
race_id,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2,Unnamed: 25_level_2,Unnamed: 26_level_2,Unnamed: 27_level_2,Unnamed: 28_level_2,Unnamed: 29_level_2,Unnamed: 30_level_2,Unnamed: 31_level_2,Unnamed: 32_level_2,Unnamed: 33_level_2,Unnamed: 34_level_2,Unnamed: 35_level_2,Unnamed: 36_level_2,Unnamed: 37_level_2,Unnamed: 38_level_2,Unnamed: 39_level_2,Unnamed: 40_level_2,Unnamed: 41_level_2,Unnamed: 42_level_2,Unnamed: 43_level_2,Unnamed: 44_level_2,Unnamed: 45_level_2,Unnamed: 46_level_2,Unnamed: 47_level_2,Unnamed: 48_level_2,Unnamed: 49_level_2,Unnamed: 50_level_2,Unnamed: 51_level_2,Unnamed: 52_level_2,Unnamed: 53_level_2,Unnamed: 54_level_2,Unnamed: 55_level_2,Unnamed: 56_level_2,Unnamed: 57_level_2,Unnamed: 58_level_2,Unnamed: 59_level_2,Unnamed: 60_level_2,Unnamed: 61_level_2,Unnamed: 62_level_2,Unnamed: 63_level_2,Unnamed: 64_level_2,Unnamed: 65_level_2,Unnamed: 66_level_2,Unnamed: 67_level_2,Unnamed: 68_level_2,Unnamed: 69_level_2,Unnamed: 70_level_2,Unnamed: 71_level_2,Unnamed: 72_level_2,Unnamed: 73_level_2,Unnamed: 74_level_2,Unnamed: 75_level_2,Unnamed: 76_level_2,Unnamed: 77_level_2,Unnamed: 78_level_2,Unnamed: 79_level_2,Unnamed: 80_level_2,Unnamed: 81_level_2,Unnamed: 82_level_2,Unnamed: 83_level_2,Unnamed: 84_level_2,Unnamed: 85_level_2,Unnamed: 86_level_2,Unnamed: 87_level_2,Unnamed: 88_level_2,Unnamed: 89_level_2,Unnamed: 90_level_2,Unnamed: 91_level_2,Unnamed: 92_level_2,Unnamed: 93_level_2,Unnamed: 94_level_2,Unnamed: 95_level_2,Unnamed: 96_level_2,Unnamed: 97_level_2,Unnamed: 98_level_2,Unnamed: 99_level_2,Unnamed: 
100_level_2,Unnamed: 101_level_2,Unnamed: 102_level_2,Unnamed: 103_level_2,Unnamed: 104_level_2,Unnamed: 105_level_2,Unnamed: 106_level_2,Unnamed: 107_level_2,Unnamed: 108_level_2,Unnamed: 109_level_2,Unnamed: 110_level_2,Unnamed: 111_level_2,Unnamed: 112_level_2,Unnamed: 113_level_2,Unnamed: 114_level_2,Unnamed: 115_level_2,Unnamed: 116_level_2,Unnamed: 117_level_2,Unnamed: 118_level_2,Unnamed: 119_level_2,Unnamed: 120_level_2,Unnamed: 121_level_2,Unnamed: 122_level_2,Unnamed: 123_level_2,Unnamed: 124_level_2,Unnamed: 125_level_2,Unnamed: 126_level_2,Unnamed: 127_level_2,Unnamed: 128_level_2,Unnamed: 129_level_2,Unnamed: 130_level_2,Unnamed: 131_level_2,Unnamed: 132_level_2,Unnamed: 133_level_2,Unnamed: 134_level_2,Unnamed: 135_level_2,Unnamed: 136_level_2,Unnamed: 137_level_2,Unnamed: 138_level_2,Unnamed: 139_level_2,Unnamed: 140_level_2,Unnamed: 141_level_2,Unnamed: 142_level_2,Unnamed: 143_level_2,Unnamed: 144_level_2,Unnamed: 145_level_2,Unnamed: 146_level_2,Unnamed: 147_level_2,Unnamed: 148_level_2,Unnamed: 149_level_2,Unnamed: 150_level_2,Unnamed: 151_level_2,Unnamed: 152_level_2,Unnamed: 153_level_2,Unnamed: 154_level_2,Unnamed: 155_level_2,Unnamed: 156_level_2,Unnamed: 157_level_2,Unnamed: 158_level_2,Unnamed: 159_level_2,Unnamed: 160_level_2,Unnamed: 161_level_2,Unnamed: 162_level_2,Unnamed: 163_level_2,Unnamed: 164_level_2,Unnamed: 165_level_2,Unnamed: 166_level_2,Unnamed: 167_level_2,Unnamed: 168_level_2,Unnamed: 169_level_2,Unnamed: 170_level_2,Unnamed: 171_level_2,Unnamed: 172_level_2,Unnamed: 173_level_2,Unnamed: 174_level_2,Unnamed: 175_level_2,Unnamed: 176_level_2,Unnamed: 177_level_2,Unnamed: 178_level_2,Unnamed: 179_level_2,Unnamed: 180_level_2,Unnamed: 181_level_2,Unnamed: 182_level_2,Unnamed: 183_level_2,Unnamed: 184_level_2,Unnamed: 185_level_2,Unnamed: 186_level_2,Unnamed: 187_level_2,Unnamed: 188_level_2,Unnamed: 189_level_2,Unnamed: 190_level_2,Unnamed: 191_level_2,Unnamed: 192_level_2,Unnamed: 193_level_2,Unnamed: 194_level_2,Unnamed: 
195_level_2,Unnamed: 196_level_2,Unnamed: 197_level_2,Unnamed: 198_level_2,Unnamed: 199_level_2,Unnamed: 200_level_2,Unnamed: 201_level_2,Unnamed: 202_level_2,Unnamed: 203_level_2,Unnamed: 204_level_2,Unnamed: 205_level_2,Unnamed: 206_level_2,Unnamed: 207_level_2,Unnamed: 208_level_2,Unnamed: 209_level_2,Unnamed: 210_level_2,Unnamed: 211_level_2,Unnamed: 212_level_2,Unnamed: 213_level_2,Unnamed: 214_level_2,Unnamed: 215_level_2,Unnamed: 216_level_2,Unnamed: 217_level_2,Unnamed: 218_level_2,Unnamed: 219_level_2,Unnamed: 220_level_2,Unnamed: 221_level_2,Unnamed: 222_level_2,Unnamed: 223_level_2,Unnamed: 224_level_2,Unnamed: 225_level_2,Unnamed: 226_level_2,Unnamed: 227_level_2,Unnamed: 228_level_2,Unnamed: 229_level_2,Unnamed: 230_level_2,Unnamed: 231_level_2,Unnamed: 232_level_2,Unnamed: 233_level_2,Unnamed: 234_level_2,Unnamed: 235_level_2,Unnamed: 236_level_2,Unnamed: 237_level_2,Unnamed: 238_level_2,Unnamed: 239_level_2,Unnamed: 240_level_2,Unnamed: 241_level_2,Unnamed: 242_level_2,Unnamed: 243_level_2,Unnamed: 244_level_2,Unnamed: 245_level_2,Unnamed: 246_level_2,Unnamed: 247_level_2,Unnamed: 248_level_2,Unnamed: 249_level_2,Unnamed: 250_level_2,Unnamed: 251_level_2,Unnamed: 252_level_2,Unnamed: 253_level_2,Unnamed: 254_level_2,Unnamed: 255_level_2,Unnamed: 256_level_2,Unnamed: 257_level_2,Unnamed: 258_level_2,Unnamed: 259_level_2,Unnamed: 260_level_2,Unnamed: 261_level_2,Unnamed: 262_level_2,Unnamed: 263_level_2,Unnamed: 264_level_2,Unnamed: 265_level_2,Unnamed: 266_level_2,Unnamed: 267_level_2,Unnamed: 268_level_2,Unnamed: 269_level_2,Unnamed: 270_level_2,Unnamed: 271_level_2,Unnamed: 272_level_2,Unnamed: 273_level_2,Unnamed: 274_level_2,Unnamed: 275_level_2,Unnamed: 276_level_2,Unnamed: 277_level_2,Unnamed: 278_level_2,Unnamed: 279_level_2,Unnamed: 280_level_2,Unnamed: 281_level_2,Unnamed: 282_level_2,Unnamed: 283_level_2,Unnamed: 284_level_2,Unnamed: 285_level_2,Unnamed: 286_level_2,Unnamed: 287_level_2,Unnamed: 288_level_2,Unnamed: 289_level_2,Unnamed: 
290_level_2,Unnamed: 291_level_2,Unnamed: 292_level_2,Unnamed: 293_level_2,Unnamed: 294_level_2,Unnamed: 295_level_2,Unnamed: 296_level_2,Unnamed: 297_level_2,Unnamed: 298_level_2,Unnamed: 299_level_2,Unnamed: 300_level_2,Unnamed: 301_level_2,Unnamed: 302_level_2,Unnamed: 303_level_2,Unnamed: 304_level_2,Unnamed: 305_level_2,Unnamed: 306_level_2,Unnamed: 307_level_2,Unnamed: 308_level_2,Unnamed: 309_level_2,Unnamed: 310_level_2,Unnamed: 311_level_2,Unnamed: 312_level_2,Unnamed: 313_level_2,Unnamed: 314_level_2,Unnamed: 315_level_2,Unnamed: 316_level_2,Unnamed: 317_level_2,Unnamed: 318_level_2,Unnamed: 319_level_2,Unnamed: 320_level_2
11504,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,365.0,365.0,365.0,365.0,365.0,365.0,365.0,365.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
11505,4.0,6.0,9.0,7.0,5.0,8.0,5.0,5.0,5.0,4.0,5.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,365.0,365.0,365.0,365.0,365.0,365.0,365.0,365.0,365.0,365.0,365.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
11506,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,365.0,365.0,365.0,365.0,365.0,365.0,365.0,365.0,365.0,365.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
11507,5.0,4.0,4.0,4.0,4.0,4.0,8.0,4.0,4.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,365.0,365.0,365.0,365.0,365.0,365.0,365.0,365.0,365.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
11508,7.0,4.0,5.0,4.0,5.0,4.0,5.0,5.0,4.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,365.0,365.0,365.0,365.0,365.0,365.0,365.0,365.0,365.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [25]:
# build the neural network

# Layer widths come straight from the dataset: the second dimension of y gives
# the number of output nodes and the second dimension of X the number of inputs.
output_layer_nodes = train_data.y.shape[1]
input_layer_nodes = train_data.X.shape[1]

# torch.manual_seed(0)  # uncomment to fix torch's random seed before the network is created
# The architecture class is whichever one the selected model_inventory entry names.
net = model_architecture(input_layer_nodes, output_layer_nodes, bias=bias).to(device)
print(f"Model: {model}")
# Print the module itself so the label shows the actual structure, not just the inventory key.
print(f"Model structure: {net}")

# Dump every learnable parameter with its shape and current (initial) values.
for name, param in net.named_parameters():
    print(f"Layer: {name} | Size: {param.size()} | Values : {param} \n")

Model structure: AlunOwen_v3
Layer: neural_network.0.weights | Size: torch.Size([20]) | Values : Parameter containing:
tensor([-0.3507, -0.3330,  0.1525,  0.3119,  0.2764, -0.0767,  0.4799,  0.2086,
         0.1461,  0.4840, -0.1382, -0.0034, -0.1200,  0.0239,  0.0616, -0.1162,
        -0.0818,  0.0359, -0.1207, -0.0234], requires_grad=True) 



In [26]:
# example to show how the model is used for prediction

# One random input row with as many features as the network has input nodes.
X = torch.rand(1, input_layer_nodes, device=device)
# Inference only: no_grad() avoids building an autograd graph for this forward pass.
with torch.no_grad():
    logits = net(X)
# Index of the largest output along dimension 1.
y_pred = logits.argmax(1)
print(f"Predicted class: {y_pred}")
print(X.shape)

Predicted class: tensor([6])
torch.Size([1, 320])


In [27]:
%env WANDB_NOTEBOOK_NAME 'C:\Users\gille\OneDrive\1-Projects\_Horse Racing 2H22\New Framework\3b_Deep Learning.ipynb'

# start a new wandb run to track this script
wandb.init(
    # set the wandb project where this run will be logged
    project="horse-racing-project",
    
    # track hyperparameters and run metadata
    config={
    "device": device,
    "model": model,
    "X_columns": X_columns,
    "learning_rate": learning_rate,
    "epochs": epochs,
    "vacant_stall_indicator": vacant_stall_indicator,
    "bias": bias,
    "model_architecture": list(net.modules())
    }
)

env: WANDB_NOTEBOOK_NAME='C:\Users\gille\OneDrive\1-Projects\_Horse Racing 2H22\New Framework\3b_Deep Learning.ipynb'


In [28]:
# optimizing model parameters

# Loss and optimiser: cross-entropy, minimised with plain SGD at the configured rate.
# NOTE(review): nn.CrossEntropyLoss applies log-softmax itself; if the selected
# architecture already ends in a softmax/sigmoid the outputs are squashed twice - confirm.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)

# One pass over the training loader, then one evaluation on the test loader,
# per epoch; the evaluation metrics are sent to wandb.
for epoch_number in range(1, epochs + 1):
    print(f"Epoch {epoch_number}\n-------------------------------")
    train_loop(train_dataloader, net, loss_fn, optimizer, device)
    acc, loss = test_loop(test_dataloader, net, loss_fn, device)
    wandb.log({"acc": acc, "loss": loss})
print("Done!")

Epoch 1
-------------------------------
loss: 2.772957  [   64/16620]
loss: 2.760281  [ 6464/16620]
loss: 2.811059  [12864/16620]
Test Error: 
 Accuracy: 15.9%, Avg loss: 2.753920 

Epoch 2
-------------------------------
loss: 2.769868  [   64/16620]
loss: 2.747210  [ 6464/16620]
loss: 2.799972  [12864/16620]
Test Error: 
 Accuracy: 18.0%, Avg loss: 2.741208 

Epoch 3
-------------------------------
loss: 2.764286  [   64/16620]
loss: 2.733308  [ 6464/16620]
loss: 2.787822  [12864/16620]
Test Error: 
 Accuracy: 19.1%, Avg loss: 2.728886 

Epoch 4
-------------------------------
loss: 2.756614  [   64/16620]
loss: 2.720457  [ 6464/16620]
loss: 2.775795  [12864/16620]
Test Error: 
 Accuracy: 19.9%, Avg loss: 2.717927 

Epoch 5
-------------------------------
loss: 2.748583  [   64/16620]
loss: 2.709217  [ 6464/16620]
loss: 2.764917  [12864/16620]
Test Error: 
 Accuracy: 20.4%, Avg loss: 2.708689 

Epoch 6
-------------------------------
loss: 2.741418  [   64/16620]
loss: 2.699215  [ 64

In [29]:
# finish the wandb run, necessary in notebooks
# (ends the run started by wandb.init() earlier in this notebook)
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
acc,▁▄▆▇▆▇▇▇▇▇▇▇██████▇██▇██████████████████
loss,█▆▄▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
acc,0.22897
loss,2.65105


In [30]:
# Write every named parameter of the network to "<parameter name>.csv" in the
# current working directory (comma separated, 3 decimal places).
for para_name, para_vals in net.named_parameters():
    # detach() drops the autograd link and cpu() moves the values to host
    # memory, so the export also works when the model lives on a GPU
    # (Tensor.numpy() raises for CUDA tensors).
    np.savetxt(para_name + ".csv", para_vals.detach().cpu().numpy(), fmt='%6.3f', delimiter=",")