Imports:

In [429]:
import pickle

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import mean_pinball_loss
from torch import nn, optim
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from torch.utils.data import DataLoader, Dataset

from consts import JULY, DEF_QUANTILES

Data loading:

In [430]:
with open('playground_input.pkl', 'rb') as f:
    data = pickle.load(f)

X, y = data['train']
val_X, val_y = data['val']

The fun stuff:

In [431]:
class SequenceDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        assert np.all(self.X.forecast_year.iloc[:-1].values <= self.X.forecast_year.iloc[1:].values), \
            'Error - not sorted by forecast year!'

        fy = self.X.forecast_year.iloc[idx]
        init_ind = (self.X.forecast_year == fy).argmax()
        # Create sequence from rows 0 to idx
        sequence = self.X.iloc[init_ind:idx + 1].drop(columns='forecast_year').values
        label = self.y.iloc[idx]
        return torch.tensor(sequence, dtype=torch.float32), torch.tensor(label, dtype=torch.float32)


In [432]:
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size=128, num_layers=3, output_size=1):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, lengths):
        # Pack the padded sequences
        x_packed = pack_padded_sequence(x, lengths, batch_first=True)
        out_packed, _ = self.lstm(x_packed)
        out_padded, _ = pad_packed_sequence(out_packed, batch_first=True)
        # Apply the linear layer to the unpacked outputs
        out = self.fc(out_padded)
        return out[:, -1, :]  # Return the outputs for the last time step


In [433]:
def pad_collate_fn(batch):
    # Sort the batch by sequence length in descending order
    batch.sort(key=lambda x: len(x[0]), reverse=True)
    sequences, labels = zip(*batch)
    # Pad the sequences and stack the labels
    padded_sequences = pad_sequence(sequences, batch_first=True)
    lengths = [len(seq) for seq in sequences]
    labels = torch.stack(labels)
    return padded_sequences, labels, lengths


In [434]:
def features2seqs(X: pd.DataFrame, y: pd.Series, train: bool = True):
    X = X[X.date.dt.month <= JULY].drop(columns=['date'])
    if train:
        return SequenceDataset(X.iloc[:32], y.iloc[:32])

    raise NotImplementedError

In [435]:
bs = 32
lr = 1e-1

In [436]:
def quantile_loss(quantile: float):
    def qloss(y_true, y_pred):
        return torch.mean(torch.max(quantile * (y_true - y_pred), -(1 - quantile) * (y_true - y_pred)))

    return qloss


def avg_quantile_loss(y_true, y_pred):
    return torch.mean(torch.stack([quantile_loss(q)(y_true, y_pred) for q in DEF_QUANTILES]))

In [437]:
train_set = features2seqs(X, y)  # todo see we can overfit to a small training set before continuing
combined_X = pd.concat([X, val_X])
combined_y = pd.concat([y, val_y])
combined_set = features2seqs(combined_X, combined_y)

dataloader = DataLoader(train_set, batch_size=bs, shuffle=True, collate_fn=pad_collate_fn)

n_feats = train_set[0][0].shape[1]
model = LSTMModel(input_size=n_feats)

In [438]:
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.MSELoss()  # todo implement AQM loss, requires multioutput (use dummy std for starters)

In [439]:
X[X.date.dt.month <= JULY].drop(columns=['date', 'forecast_year'])

Unnamed: 0,oniANOM,oniTOTAL,max_height,min_height,mjo70E,SWE_volume_m3,percent_diff_over_1000,soi_sd,catchment_area,site_max_height_diff,...,percent_over_2000,ninoNINO1+2,mjo100E,med_height,percent_over_1000,ninoNINO3,ninoANOM.2,ninoANOM.3,ninoNINO4,time
12,-0.499490,-0.854725,0.953922,-1.469843,-1.190694,0.167086,1.998899,0.037966,-0.353614,3.11618,...,0.11328,-0.738090,-0.236802,0.140599,-0.359338,-1.399001,-0.355988,-0.581102,-0.416069,-1.678157
13,-0.533906,-0.875589,0.953922,-1.469843,0.015006,0.348750,1.998899,0.037966,-0.353614,3.11618,...,0.11328,-0.738090,-1.829944,0.140599,-0.359338,-1.399001,-0.355988,-0.581102,-0.416069,-1.611691
14,-0.568321,-0.896453,0.953922,-1.469843,2.314010,0.335406,1.998899,0.037966,-0.353614,3.11618,...,0.11328,-0.738090,-0.398133,0.140599,-0.359338,-1.399001,-0.355988,-0.581102,-0.416069,-1.545224
15,-0.563026,-0.843134,0.953922,-1.469843,1.118528,0.284683,1.998899,0.863044,-0.353614,3.11618,...,0.11328,0.261551,0.670683,0.140599,-0.359338,-1.076117,-0.624270,-0.730655,-0.973397,-1.478758
16,-0.555463,-0.766965,0.953922,-1.469843,-0.138261,0.374512,1.998899,0.863044,-0.353614,3.11618,...,0.11328,0.261551,1.477337,0.140599,-0.359338,-1.076117,-0.624270,-0.730655,-0.973397,-1.383807
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
851,-0.115778,0.493264,0.953922,-1.469843,-1.252000,-0.452350,1.998899,-0.787112,-0.353614,3.11618,...,0.11328,-0.788998,-0.156137,0.140599,-0.359338,-0.083841,0.180574,-0.293499,0.642856,-0.044987
852,-0.031366,0.471705,0.953922,-1.469843,-0.015647,-0.474310,1.998899,-0.787112,-0.353614,3.11618,...,0.11328,-0.788998,-0.761127,0.140599,-0.359338,-0.083841,0.180574,-0.293499,0.642856,0.040470
853,0.034288,0.454937,0.953922,-1.469843,0.015006,-0.477171,1.998899,-0.787112,-0.353614,3.11618,...,0.11328,-0.788998,-1.144288,0.140599,-0.359338,-0.083841,0.180574,-0.293499,0.642856,0.106936
854,0.099942,0.438168,0.953922,-1.469843,0.341975,-0.477171,1.998899,-0.787112,-0.353614,3.11618,...,0.11328,-0.788998,-0.942624,0.140599,-0.359338,-0.083841,0.180574,-0.293499,0.642856,0.173402


In [440]:
y.iloc[:32]

0    -0.032552
1    -0.032552
2    -0.032552
3    -0.032552
4    -0.032552
5    -0.032552
6    -0.032552
7    -0.032552
8    -0.032552
9    -0.032552
10   -0.032552
11   -0.032552
12   -0.032552
13   -0.032552
14   -0.032552
15   -0.032552
16   -0.032552
17   -0.032552
18   -0.032552
19   -0.032552
20   -0.032552
21   -0.032552
22   -0.032552
23   -0.032552
24   -0.032552
25   -0.032552
26   -0.032552
27   -0.032552
28   -1.063455
29   -1.063455
30   -1.063455
31   -1.063455
Name: volume, dtype: float64

In [441]:
num_epochs = 250
for epoch in range(num_epochs):
    for sequences, labels, lengths in dataloader:
        optimizer.zero_grad()
        outputs = model(sequences, lengths)
        outputs = outputs.squeeze()  # todo remove/change when using a multioutput
        # Ensure labels are also squeezed to match output shape
        loss = criterion(labels, outputs)
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

Epoch [1/250], Loss: 0.1461
Epoch [2/250], Loss: 2.7061
Epoch [3/250], Loss: 0.2445
Epoch [4/250], Loss: 0.2809
Epoch [5/250], Loss: 0.2125
Epoch [6/250], Loss: 0.1738
Epoch [7/250], Loss: 0.1435
Epoch [8/250], Loss: 0.1266
Epoch [9/250], Loss: 0.1231
Epoch [10/250], Loss: 0.1285
Epoch [11/250], Loss: 0.1363
Epoch [12/250], Loss: 0.1416
Epoch [13/250], Loss: 0.1422
Epoch [14/250], Loss: 0.1387
Epoch [15/250], Loss: 0.1334
Epoch [16/250], Loss: 0.1287
Epoch [17/250], Loss: 0.1259
Epoch [18/250], Loss: 0.1249
Epoch [19/250], Loss: 0.1248
Epoch [20/250], Loss: 0.1245
Epoch [21/250], Loss: 0.1237
Epoch [22/250], Loss: 0.1224
Epoch [23/250], Loss: 0.1209
Epoch [24/250], Loss: 0.1199
Epoch [25/250], Loss: 0.1198
Epoch [26/250], Loss: 0.1204
Epoch [27/250], Loss: 0.1212
Epoch [28/250], Loss: 0.1215
Epoch [29/250], Loss: 0.1211
Epoch [30/250], Loss: 0.1199
Epoch [31/250], Loss: 0.1183
Epoch [32/250], Loss: 0.1169
Epoch [33/250], Loss: 0.1162
Epoch [34/250], Loss: 0.1164
Epoch [35/250], Loss: 0

Empirical quantiles training loss:

In [442]:
np.mean([mean_pinball_loss(y, [y.quantile(q)] * len(y), alpha=q) for q in DEF_QUANTILES])

0.2405846650599108