In [None]:
# Quick-run switch. When True, the training cell keeps only the first 4*80 rows
# of train.csv and trains for 10 epochs instead of 50, and the inference cell
# predicts only the first 80*4 rows of test.csv — useful for a fast smoke test.
DEBUG = False

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow import keras
from tensorflow.data import Dataset
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LSTM, LayerNormalization
from tensorflow.keras.callbacks import EarlyStopping

import os

# Create Model & Dataset

In [None]:
def select_data(csv, columns):
    """Return ``columns`` of ``csv`` as one stacked array block per breath.

    Rows are grouped by ``breath_id``; groups are visited in pandas' default
    (sorted) key order and rows inside a group keep their order from ``csv``.
    Each group's ``columns`` become a 2-D array, and the list of those arrays
    is handed to numpy. When every breath has the same number of rows the
    result has shape ``(n_breaths, rows_per_breath, len(columns))``.
    """
    per_breath = []
    for _breath_key, breath_rows in csv.groupby('breath_id'):
        per_breath.append(breath_rows[columns].to_numpy())
    return np.asarray(per_breath)

def select_features(csv):
    """Per-breath model inputs: breath_id, R, C, time_step, u_in, u_out.

    NOTE(review): ``breath_id`` is an identifier, not a measurement; feeding it
    as a feature (and normalising it together with the others) is probably
    unintended — confirm before relying on it. Changing this list also
    changes the model's expected input width.
    """
    feature_columns = ['breath_id', 'R', 'C', 'time_step', 'u_in', 'u_out']
    return select_data(csv, feature_columns)

def select_labels(csv):
    """Per-breath training target: the ``pressure`` column, one value per row."""
    label_columns = ['pressure']
    return select_data(csv, label_columns)

In [None]:
# Stop once validation loss has not improved for 5 epochs and roll back to the
# weights of the best validation epoch. Without `restore_best_weights`, the
# model that is later saved and used for prediction would carry the weights of
# the last epoch, which may be up to `patience` epochs past the best one.
earlyStopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Sequence-to-sequence model over whole breaths: each sample is
# (80 time steps, 6 features) and every time step gets one output value.
model = Sequential([
    # Normalises each time step across its input features (last axis).
    LayerNormalization(input_shape=(80, 6)),
    LSTM(64, return_sequences=True, dropout=0.3),
    LSTM(32, return_sequences=True, dropout=0.2),
    # NOTE(review): a ReLU-activated LSTM(1) output can only emit values >= 0;
    # if `pressure` can be negative, a linear Dense(1) head would be safer —
    # confirm against the data.
    LSTM(1, activation='relu', return_sequences=True),
])

model.compile(optimizer='RMSprop', loss='mae', metrics=['mse'])
model.summary()

# Training

In [None]:
csv = pd.read_csv('/kaggle/input/ventilator-pressure-prediction/train.csv')
epochs = 50
# Fixed seed so the train/validation split (and hence early stopping) is
# reproducible across runs.
SPLIT_SEED = 42

if DEBUG:
    # Keep 4 breaths, assuming 80 rows per breath (the model's
    # input_shape=(80, 6) makes the same assumption).
    csv = csv.head(4*80)
    epochs = 10

# Both arrays are built with the same per-breath grouping, so entry i of
# `features` and `labels` belongs to the same breath.
features = select_features(csv)
labels = select_labels(csv)

# The split is over whole breaths (first axis), so no breath ends up in both
# the training and the validation set.
X_train, X_test, y_train, y_test = train_test_split(features, labels, random_state=SPLIT_SEED)

# NOTE: `model` is created in the model cell; re-running only this cell keeps
# training the already-fitted weights instead of starting from scratch.
model.fit(X_train, y_train, batch_size=512, epochs=epochs, validation_data=(X_test, y_test), callbacks=[earlyStopping])
model.save('model')

# Inference & Submission 

In [None]:
csv_predict = pd.read_csv('/kaggle/input/ventilator-pressure-prediction/test.csv')
if DEBUG:
    csv_predict = csv_predict.head(80 * 4)

prediction = model.predict(select_features(csv_predict), batch_size=512)

# `select_features` groups rows by breath_id (sorted), so predictions come out
# in grouped order. Take the ids through the same grouping rather than assuming
# the file's row order already matches it.
prediction_ids = select_data(csv_predict, ['id']).ravel()
prediction_values = prediction.ravel()
assert len(prediction_ids) == len(prediction_values), 'one prediction per test row expected'

submission = (
    pd.DataFrame({'id': prediction_ids, 'pressure': prediction_values})
    .sort_values('id')
)
submission.to_csv('submission.csv', index=False)