In [None]:
import pandas as pd
from concurrent.futures import ThreadPoolExecutor
import warnings
import os
from tqdm import tqdm
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset, random_split
from torch import nn
from pytorch_lightning import seed_everything, Trainer, loggers
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
import torch

# Silence every Python warning for the rest of the session.
# NOTE(review): a blanket 'ignore' filter also hides deprecation and
# configuration warnings emitted by libraries; consider narrowing it.
warnings.filterwarnings('ignore')
# Seed the random number generators through Lightning's helper
# (the recorded cell output shows "Seed set to 42").
seed_everything(42)
# Select the 'medium' internal-precision setting for float32 matrix multiplications.
torch.set_float32_matmul_precision('medium')

Seed set to 42


42

In [7]:
# Preview a single series partition (id=0a418b57), ordered by its `step` column.
df = pd.read_parquet("../data/input/series_train.parquet/id=0a418b57/part-0.parquet").sort_values("step")
df.head()

Unnamed: 0,step,X,Y,Z,enmo,anglez,non-wear_flag,light,battery_voltage,time_of_day,weekday,quarter,relative_date_PCIAT
0,0,-0.075242,-0.256743,-0.973791,0.038081,-72.952141,0.0,5.0,4202.0,51250000000000,2,4,-9.0
1,1,-0.265893,-0.270508,-0.76547,0.07743,-52.84922,0.0,0.5,4185.333496,51255000000000,2,4,-9.0
2,2,0.334517,-0.548602,-0.588596,0.039162,-44.118084,0.0,11.5,4185.5,51260000000000,2,4,-9.0
3,3,0.000193,-0.021069,-0.999681,0.00145,-88.759613,0.0,0.0,4185.666504,51265000000000,2,4,-9.0
4,4,-0.000685,-0.020681,-0.997677,0.000491,-88.756958,0.0,8.5,4185.833496,51270000000000,2,4,-9.0


In [None]:
class LSTM(pl.LightningModule):
    """Single-layer LSTM followed by a latent projection and a classification head.

    ``forward`` returns both the latent vector and the class scores, so the
    module can be used for training (scores) and for feature extraction (latent).

    Args:
        hidden_size: Width of the LSTM hidden state and of the latent layer.
        seq_length: Value passed to ``nn.LSTM`` as ``input_size``.
        lr: Learning rate for the Adam optimizer.
        criterion: Callable ``criterion(y_pred, y)`` returning the loss.
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, hidden_size, seq_length, lr, criterion, num_classes=4):
        super().__init__()
        self.hidden_size = hidden_size
        self.seq_length = seq_length
        self.lr = lr
        self.criterion = criterion
        # NOTE(review): nn.LSTM's `input_size` is the size of the last (feature)
        # dimension of the input, yet it is set from `seq_length` here. This is
        # only correct if the dataset yields tensors whose last dimension equals
        # `seq_length` -- TODO confirm against the Dataset implementation.
        self.lstm = nn.LSTM(input_size=self.seq_length, hidden_size=self.hidden_size, num_layers=1, batch_first=True)
        self.latent = nn.Linear(self.hidden_size, self.hidden_size)
        self.linear = nn.Linear(self.hidden_size, num_classes)

    def forward(self, x):
        """Run the network on a batch-first 3-D input.

        Returns:
            tuple: ``(latent, y_pred)`` where ``latent`` is the projection of the
            LSTM output at the last position of dimension 1 and ``y_pred`` is the
            output of the final linear layer applied to ``latent``.
        """
        lstm_out, _ = self.lstm(x)
        # Keep only the last position along dimension 1 of the LSTM output.
        latent = self.latent(lstm_out[:, -1, :])
        y_pred = self.linear(latent)
        return latent, y_pred

    def _shared_step(self, batch, stage):
        """Compute and log the loss for one batch; ``stage`` is 'train' or 'val'."""
        x, y = batch
        # forward() returns (latent, y_pred); only the class scores feed the loss.
        _, y_pred = self(x)
        loss = self.criterion(y_pred, y)
        self.log(f'{stage}_loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        # NOTE(review): `TUNING` and `wandb` are read as notebook-level globals and
        # are not defined in the cells visible here -- they must exist before training.
        if TUNING:
            wandb.log({f"{stage}_loss": loss, "epoch": self.current_epoch})
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch`` (logged as 'train_loss')."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for ``batch`` (logged as 'val_loss')."""
        return self._shared_step(batch, 'val')

    def predict_step(self, batch, batch_idx):
        """Return the ``(latent, y_pred)`` tuple produced by ``forward`` for ``batch``."""
        x = batch
        y_pred = self(x)
        return y_pred

    def configure_optimizers(self):
        """Build the Adam optimizer and a ReduceLROnPlateau scheduler.

        The scheduler is nested under the 'lr_scheduler' key together with the
        metric it monitors; a top-level 'scheduler' key is not part of Lightning's
        optimizer configuration and leaves the scheduler unattached.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        # NOTE(review): `verbose` is deprecated in recent PyTorch releases --
        # confirm it is still accepted by the installed version.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.9, verbose=True)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                # Epoch-level aggregate of the 'val_loss' value logged in validation_step.
                'monitor': 'val_loss_epoch',
            },
        }

    def lr_scheduler_step(self, scheduler, metric):
        """Step ``scheduler``, forwarding the monitored ``metric`` when one is given."""
        if metric is not None:
            scheduler.step(metric)
        else:
            scheduler.step()

In [9]:
def process_file(filename, dirname):
    """Load one series partition and tag every row with its series id.

    Reads ``<dirname>/<filename>/part-0.parquet``, orders the rows by ``step``
    (ascending), drops the ``step`` column and adds an ``id`` column holding the
    text that follows the first ``=`` in ``filename``.

    Args:
        filename: Partition directory name, e.g. ``"id=0a418b57"``.
        dirname: Directory containing the partition directories.

    Returns:
        pandas.DataFrame: The sorted frame without ``step`` and with ``id`` appended.
    """
    parquet_path = os.path.join(dirname, filename, 'part-0.parquet')
    return (
        pd.read_parquet(parquet_path)
        .sort_values(by='step', ascending=True)
        .drop('step', axis=1)
        .assign(id=filename.split('=')[1])
    )

def load_time_series(dirname):
    """Load every series partition under ``dirname`` into one DataFrame.

    Each non-hidden entry of ``dirname`` is passed to ``process_file`` on a
    thread pool, and the resulting frames are concatenated in sorted entry
    order, so the row order is the same on every run.

    Args:
        dirname: Directory whose entries are partition directories.

    Returns:
        pandas.DataFrame: Concatenation of all per-partition frames. The
        per-partition indices are kept, so index values repeat across partitions.

    Raises:
        ValueError: If ``dirname`` contains no non-hidden entries.
    """
    # os.listdir returns entries in arbitrary order; sort them so the
    # concatenated frame is reproducible across runs and machines.
    ids = sorted(f for f in os.listdir(dirname) if not f.startswith('.'))
    if not ids:
        # Fail early with a clear message instead of pandas' generic
        # "No objects to concatenate" error.
        raise ValueError(f"No series directories found in {dirname!r}")
    with ThreadPoolExecutor() as executor:
        # executor.map yields results in the order of `ids`.
        results = list(tqdm(executor.map(lambda fname: process_file(fname, dirname), ids), total=len(ids)))
    return pd.concat(results)

In [10]:
# Read and concatenate every series under the training directory
# (the recorded run processed 996 entries into 314,569,149 rows x 13 columns).
train_parquet = load_time_series("../data/input/series_train.parquet")
# test_parquet = load_time_series("../data/input/series_test.parquet")
train_parquet.shape

  0%|          | 0/996 [00:00<?, ?it/s]

100%|██████████| 996/996 [00:06<00:00, 161.15it/s]


(314569149, 13)