In [6]:
from data import OptiverDataset
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from tqdm.notebook import tqdm
import os

In [7]:
class OptiverLSTM(nn.Module):
    def __init__(self, hidden_size=128, layers=2, dropout=0.2):
        super().__init__()
        self.lstm = nn.LSTM(input_size=14, 
                            hidden_size=hidden_size, 
                            num_layers=layers, 
                            batch_first=True, 
                            dropout=dropout,
                            proj_size=1
                           )
    def forward(self, x):
        x, _ = self.lstm(x)
        return x

In [8]:
def collate_fn(data):
    return torch.stack([item[0] for item in data]), torch.stack([item[1] for item in data])

In [5]:
def train_lstm(hidden_size, layers, dropout, num_epochs=5): 
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = OptiverLSTM(hidden_size=hidden_size, layers=layers, dropout=dropout)
    model = model.to(device)
    
    batch_size = 1
    
    train_dataset = OptiverDataset(split='train')
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    
    val_dataset = OptiverDataset(split='val')
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    
    loss_fn = nn.L1Loss()
    optimizer = optim.AdamW(model.parameters(), lr=0.001)
    
    maes = []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}')
        model.train()
        progress = tqdm(total=len(train_dataloader), desc='Training')
        for x, y in train_dataloader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x).squeeze(dim=2)
            
            loss = loss_fn(y_pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            progress.set_postfix({'loss': loss.item()})
            progress.update()
        progress.close()
                
        model.eval()
        loss = 0
        num_items = 0
        progress = tqdm(total=len(val_dataloader), desc='Validating')
        for x, y in val_dataloader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x).squeeze(dim=2)
            loss += torch.sum(torch.absolute(y_pred - y)).item()
            num_items += torch.numel(y_pred)
            progress.update()
        progress.close()
        mae = loss/num_items
        maes.append(mae)
        print(f'MAE = {mae}')
    return min(maes)


In [19]:
import pandas as pd
hidden_sizes = [64, 128, 256]
layers = [2, 4, 6, 8]
dropouts = [0.0, 0.1, 0.2, 0.3, 0.4]

hidden_size_df = pd.DataFrame(columns=['hidden_size', 'mae'])
layer_df = pd.DataFrame(columns=['num_layers', 'mae'])
dropout_df = pd.DataFrame(columns=['dropout', 'mae'])

for hidden_size in hidden_sizes:
    mae = train_lstm(hidden_size=hidden_size, layers=4, dropout=0)
    hidden_size_df = hidden_size_df.append({'hidden_size': hidden_size, 'mae': mae}, ignore_index=True)

for num_layers in layers:
    mae = train_lstm(hidden_size=128, layers=num_layers, dropout=0)
    layer_df = layer_df.append({'num_layers': num_layers, 'mae': mae}, ignore_index=True)

for dropout in dropouts:
    mae = train_lstm(hidden_size=128, layers=4, dropout=dropout)
    dropout_df = dropout_df.append({'dropout': dropout, 'mae': mae}, ignore_index=True)

hidden_size_df
layer_df
dropout_df


Loading data:   0%|          | 0/139 [00:00<?, ?it/s][A
Loading data:  11%|█         | 15/139 [00:00<00:00, 149.03it/s][A
Loading data:  31%|███       | 43/139 [00:00<00:00, 225.35it/s][A
Loading data:  52%|█████▏    | 72/139 [00:00<00:00, 251.13it/s][A
Loading data:  73%|███████▎  | 101/139 [00:00<00:00, 264.86it/s][A
Loading data: 100%|██████████| 139/139 [00:00<00:00, 256.79it/s][A

Loading data: 100%|██████████| 20/20 [00:00<00:00, 214.17it/s]

Epoch 1





Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.014926160828354
Epoch 2


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.014386511175123
Epoch 3


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.009135658512161
Epoch 4


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 5.992236954751979
Epoch 5


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 5.983187748249347



Loading data:   0%|          | 0/139 [00:00<?, ?it/s][A
Loading data:  17%|█▋        | 24/139 [00:00<00:00, 234.47it/s][A
Loading data:  37%|███▋      | 52/139 [00:00<00:00, 260.93it/s][A
Loading data:  58%|█████▊    | 81/139 [00:00<00:00, 272.49it/s][A
Loading data:  79%|███████▉  | 110/139 [00:00<00:00, 276.26it/s][A
Loading data: 100%|██████████| 139/139 [00:00<00:00, 270.98it/s][A

Loading data: 100%|██████████| 20/20 [00:00<00:00, 214.94it/s]

Epoch 1





Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.015556671393511
Epoch 2


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.0153197163413505
Epoch 3


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.01493250214329
Epoch 4


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.015402922648265
Epoch 5


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.0151633222546685



Loading data:   0%|          | 0/139 [00:00<?, ?it/s][A
Loading data:  17%|█▋        | 24/139 [00:00<00:00, 236.49it/s][A
Loading data:  38%|███▊      | 53/139 [00:00<00:00, 264.61it/s][A
Loading data:  59%|█████▉    | 82/139 [00:00<00:00, 275.42it/s][A
Loading data: 100%|██████████| 139/139 [00:00<00:00, 275.98it/s][A

Loading data: 100%|██████████| 20/20 [00:00<00:00, 218.10it/s]

Epoch 1





Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.015254821050344
Epoch 2


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.014946329211654
Epoch 3


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.011245383972998
Epoch 4


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 5.990258727150254
Epoch 5


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 5.99236582603094



Loading data:   0%|          | 0/139 [00:00<?, ?it/s][A
Loading data:  17%|█▋        | 24/139 [00:00<00:00, 238.40it/s][A
Loading data:  38%|███▊      | 53/139 [00:00<00:00, 264.98it/s][A
Loading data:  59%|█████▉    | 82/139 [00:00<00:00, 275.27it/s][A
Loading data: 100%|██████████| 139/139 [00:00<00:00, 275.68it/s][A

Loading data: 100%|██████████| 20/20 [00:00<00:00, 232.01it/s]

Epoch 1





Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.007806984181684
Epoch 2


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 5.994216439359919
Epoch 3


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 5.986889894210357
Epoch 4


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 5.980637245116662
Epoch 5


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 5.977934231483661



Loading data:   0%|          | 0/139 [00:00<?, ?it/s][A
Loading data:  17%|█▋        | 24/139 [00:00<00:00, 235.18it/s][A
Loading data:  38%|███▊      | 53/139 [00:00<00:00, 262.44it/s][A
Loading data:  59%|█████▉    | 82/139 [00:00<00:00, 273.57it/s][A
Loading data: 100%|██████████| 139/139 [00:00<00:00, 274.22it/s][A

Loading data: 100%|██████████| 20/20 [00:00<00:00, 227.12it/s]

Epoch 1





Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.015097901647631
Epoch 2


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.0150826674827576
Epoch 3


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.015090903687659
Epoch 4


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.015091935558432
Epoch 5


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.01501567092763



Loading data:   0%|          | 0/139 [00:00<?, ?it/s][A
Loading data:  17%|█▋        | 24/139 [00:00<00:00, 232.61it/s][A
Loading data:  38%|███▊      | 53/139 [00:00<00:00, 260.76it/s][A
Loading data:  59%|█████▉    | 82/139 [00:00<00:00, 272.10it/s][A
Loading data: 100%|██████████| 139/139 [00:00<00:00, 273.69it/s][A

Loading data: 100%|██████████| 20/20 [00:00<00:00, 222.43it/s]

Epoch 1





Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.014916648855952
Epoch 2


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.014880852320756
Epoch 3


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.015157281120322
Epoch 4


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.015135836969521
Epoch 5


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.0152200188633485



Loading data:   0%|          | 0/139 [00:00<?, ?it/s][A
Loading data:  16%|█▌        | 22/139 [00:00<00:00, 218.27it/s][A
Loading data:  36%|███▌      | 50/139 [00:00<00:00, 254.33it/s][A
Loading data:  57%|█████▋    | 79/139 [00:00<00:00, 268.33it/s][A
Loading data:  78%|███████▊  | 108/139 [00:00<00:00, 275.13it/s][A
Loading data: 100%|██████████| 139/139 [00:00<00:00, 268.73it/s][A

Loading data: 100%|██████████| 20/20 [00:00<00:00, 207.62it/s]

Epoch 1





Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.015013044347479
Epoch 2


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.0150899093394585
Epoch 3


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.01525917366888
Epoch 4


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.014900945658908
Epoch 5


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.01501578349535



Loading data:   0%|          | 0/139 [00:00<?, ?it/s][A
Loading data:  18%|█▊        | 25/139 [00:00<00:00, 249.81it/s][A
Loading data:  39%|███▉      | 54/139 [00:00<00:00, 271.00it/s][A
Loading data:  60%|██████    | 84/139 [00:00<00:00, 280.44it/s][A
Loading data: 100%|██████████| 139/139 [00:00<00:00, 280.69it/s][A

Loading data: 100%|██████████| 20/20 [00:00<00:00, 232.56it/s]

Epoch 1





Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.015122310081745
Epoch 2


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.015072667716899
Epoch 3


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.014986984920128
Epoch 4


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.015131934621868
Epoch 5


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.015279792323062



Loading data:   0%|          | 0/139 [00:00<?, ?it/s][A
Loading data:  19%|█▊        | 26/139 [00:00<00:00, 252.58it/s][A
Loading data:  40%|███▉      | 55/139 [00:00<00:00, 271.85it/s][A
Loading data:  61%|██████    | 85/139 [00:00<00:00, 281.15it/s][A
Loading data: 100%|██████████| 139/139 [00:00<00:00, 280.99it/s][A

Loading data: 100%|██████████| 20/20 [00:00<00:00, 238.92it/s]

Epoch 1





Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.014920363590736
Epoch 2


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.015078314864223
Epoch 3


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.015180751490097
Epoch 4


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.015030098357172
Epoch 5


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.0150234381003616



Loading data:   0%|          | 0/139 [00:00<?, ?it/s][A
Loading data:  18%|█▊        | 25/139 [00:00<00:00, 249.06it/s][A
Loading data:  39%|███▉      | 54/139 [00:00<00:00, 271.16it/s][A
Loading data:  60%|██████    | 84/139 [00:00<00:00, 281.13it/s][A
Loading data: 100%|██████████| 139/139 [00:00<00:00, 281.31it/s][A

Loading data: 100%|██████████| 20/20 [00:00<00:00, 242.63it/s]

Epoch 1





Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.01529680881018
Epoch 2


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.01001327398563
Epoch 3


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 5.993395482972557
Epoch 4


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 5.988993372162693
Epoch 5


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 5.984555689953075



Loading data:   0%|          | 0/139 [00:00<?, ?it/s][A
Loading data:  19%|█▊        | 26/139 [00:00<00:00, 252.27it/s][A
Loading data:  40%|███▉      | 55/139 [00:00<00:00, 271.12it/s][A
Loading data:  61%|██████    | 85/139 [00:00<00:00, 280.69it/s][A
Loading data: 100%|██████████| 139/139 [00:00<00:00, 280.95it/s][A

Loading data: 100%|██████████| 20/20 [00:00<00:00, 241.04it/s]

Epoch 1





Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.014942614476869
Epoch 2


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.01499492094444
Epoch 3


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.015059103306549
Epoch 4


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.015037133839718
Epoch 5


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.015224183869016



Loading data:   0%|          | 0/139 [00:00<?, ?it/s][A
Loading data:  18%|█▊        | 25/139 [00:00<00:00, 246.88it/s][A
Loading data:  39%|███▉      | 54/139 [00:00<00:00, 269.40it/s][A
Loading data:  60%|██████    | 84/139 [00:00<00:00, 280.11it/s][A
Loading data: 100%|██████████| 139/139 [00:00<00:00, 279.63it/s][A

Loading data: 100%|██████████| 20/20 [00:00<00:00, 225.07it/s]

Epoch 1





Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.015007284632435
Epoch 2


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.015053906430109
Epoch 3


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.014929950608286
Epoch 4


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 6.011546727761421
Epoch 5


Training:   0%|          | 0/139 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

MAE = 5.999076419373895


Unnamed: 0,dropout,mae
0,0.0,6.014987
1,0.1,6.01492
2,0.2,5.984556
3,0.3,6.014943
4,0.4,5.999076


In [20]:
hidden_size_df

Unnamed: 0,hidden_size,mae
0,64.0,5.983188
1,128.0,6.014933
2,256.0,5.990259


In [21]:
layer_df

Unnamed: 0,num_layers,mae
0,2.0,5.977934
1,4.0,6.015016
2,6.0,6.014881
3,8.0,6.014901


In [22]:
dropout_df

Unnamed: 0,dropout,mae
0,0.0,6.014987
1,0.1,6.01492
2,0.2,5.984556
3,0.3,6.014943
4,0.4,5.999076


In [None]:
def test():
    model = OptiverLSTM(hidden_size=64, layers=2, dropout=0.25)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    train_dataset = OptiverDataset(split='train_val')
    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
    loss_fn = nn.L1Loss()
    optimizer = optim.AdamW(model.parameters(), lr=0.001)
    model.train()
    for i in range(5):
        progress = tqdm(total=len(train_dataloader), desc='Training')
        for x, y in train_dataloader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x).squeeze(dim=2)
            loss = loss_fn(y_pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            progress.set_postfix({'loss': loss.item()})
            progress.update()
        progress.close()
        
    test_dataset = OptiverDataset(split='test')
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
    
    model = model.to(device)
    model.eval()
    loss = 0
    num_items = 0
    progress = tqdm(total=len(test_dataloader), desc='Testing')
    for x, y in test_dataloader:
        x, y = x.to(device), y.to(device)
        y_pred = model(x).squeeze(dim=2)
        loss += torch.sum(torch.absolute(y_pred - y)).item()
        num_items += torch.numel(y_pred)
        progress.update()
    progress.close()
    mae = loss/num_items
    return mae
mae = test()

Loading data: 100%|██████████| 159/159 [00:00<00:00, 281.48it/s]


Training:   0%|          | 0/159 [00:00<?, ?it/s]

In [None]:
mae