In [8]:
import math
import torch
import pytorch_lightning as pl
import torchmetrics
import torch.nn as nn
import numpy as np

class LSTM(nn.Module):
    """Single-layer LSTM followed by a linear read-out.

    The input sequence is pushed through ``nn.LSTM`` as a batch of one, the
    hidden output of every time step is projected by ``nn.Linear``, and only
    the projection belonging to the last time step is returned.

    Args:
        input_size: number of features per time step.
        hidden_layer_size: width of the LSTM hidden and cell state.
        output_size: length of the vector produced by the linear layer.

    Attributes:
        hidden_cell: ``(h_0, c_0)`` tuple, each of shape
            ``(1, 1, hidden_layer_size)``, used as the initial state of every
            forward pass. ``forward`` never writes to it; callers may replace
            it between calls.
    """

    def __init__(self, input_size=1, hidden_layer_size=100, output_size=12):
        super().__init__()
        self.hidden_layer_size = hidden_layer_size
        self.lstm = nn.LSTM(input_size, hidden_layer_size)
        self.linear = nn.Linear(hidden_layer_size, output_size)

        # Deliberately a plain attribute (not a registered buffer) so that
        # callers can keep assigning a fresh tuple to it.
        self.hidden_cell = (torch.zeros(1, 1, self.hidden_layer_size),
                            torch.zeros(1, 1, self.hidden_layer_size))

    def forward(self, input_seq):
        """Return the linear projection of the last time step's LSTM output.

        Args:
            input_seq: tensor whose first dimension is the sequence length; it
                is reshaped to ``(seq_len, 1, -1)`` before entering the LSTM.

        Returns:
            Tensor of shape ``(output_size,)``.
        """
        seq_len = len(input_seq)
        # hidden_cell is not a parameter/buffer, so Module.to()/.cuda() never
        # moves it; align it with the input to avoid a device mismatch.
        initial_state = tuple(state.to(input_seq.device) for state in self.hidden_cell)
        # The final (h_n, c_n) is intentionally discarded: every call starts
        # from self.hidden_cell.
        lstm_out, _ = self.lstm(input_seq.view(seq_len, 1, -1), initial_state)
        predictions = self.linear(lstm_out.view(seq_len, -1))
        return predictions[-1]

    
class TimeStepVaryingLoss(nn.Module):
    """
    Weighted sum of different loss functions, each applied to its own slice
    of the sequence.

    input_seq: sequence tensor (the prediction)
    target_seq: sequence tensor
    time_step_dict: maps a loss name to the slice it scores and its weight.
    Supported names are 'mse', 'l1-loss' and 'cross-entropy':
    {'mse': [start_time_step: int, end_time_step: int, weight: float],
     'l1-loss': [start_time_step: int, end_time_step: int, weight: float]}
    Each loss is evaluated on [start_time_step:end_time_step] (end exclusive).
    """
    def __init__(self):
        super().__init__()
        self.loss_mse = nn.MSELoss()
        self.cross_entropy = nn.CrossEntropyLoss()
        self.l1_loss = nn.L1Loss()
        # Weighted terms of the most recent forward() call (replaced, not
        # extended, on every call).
        self.t_loss = []

    def forward(self, input_seq, target_seq, time_step_dict):
        """Return the sum of the weighted per-slice losses.

        Raises:
            NotImplementedError: if a key of ``time_step_dict`` is not one of
                the supported loss names.
        """
        criteria = {
            "mse": self.loss_mse,
            "cross-entropy": self.cross_entropy,
            "l1-loss": self.l1_loss,
        }
        # Collect terms in a per-call list: accumulating on the instance made
        # a reused module return the sum over all previous calls as well.
        terms = []
        for loss_key, values in time_step_dict.items():
            start, end = values[0], values[1]
            weight = torch.tensor(values[2])
            if loss_key not in criteria:
                raise NotImplementedError(f"unsupported loss key: {loss_key!r}")
            criterion = criteria[loss_key]
            terms.append(weight * criterion(input_seq[start:end], target_seq[start:end]))
        self.t_loss = terms
        return sum(terms)
        
        
class LitWrappedLSTM(pl.LightningModule):
    """
    Lightning wrapper that trains the ``LSTM`` model.

    Every training/validation item is one ``(x_sequence, y_sequence)`` pair.
    For each item the LSTM's initial state is reset to zeros, a prediction is
    made, and it is scored with ``TimeStepVaryingLoss`` using MSE on time
    steps 0-5 and L1 on time steps 6-11 (both with weight 1).
    """

    # Per-slice loss configuration handed to TimeStepVaryingLoss:
    # loss name -> [start_time_step, end_time_step, weight].
    _LOSS_SCHEDULE = {'mse': [0, 6, 1], 'l1-loss': [6, 12, 1]}

    def __init__(self, input_size=784, hidden_size=100, output_size=10):
        super().__init__()
        # NOTE(review): these three sizes are only recorded on the instance;
        # the wrapped LSTM is built with its own defaults. Forwarding them
        # would change what the default constructor builds, so confirm the
        # intended defaults before wiring them through.
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.model = LSTM()
        # Never updated by any step below (the steps optimise a regression
        # loss); kept only so the module's attributes stay unchanged.
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=10)

    def forward(self, input_seq):
        return self.model(input_seq)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        return optimizer

    def _sequence_loss(self, item):
        """Reset the LSTM state, predict from the item's input and return the loss."""
        x_sequence, y_sequence = item
        hidden_size = self.model.hidden_layer_size
        # Fresh zero state per sequence, created on the input's device so the
        # LSTM does not receive tensors from two different devices.
        self.model.hidden_cell = (
            torch.zeros(1, 1, hidden_size, device=x_sequence.device),
            torch.zeros(1, 1, hidden_size, device=x_sequence.device),
        )
        y_pred = self.model(x_sequence)
        # A new loss module per step: no loss state is carried between steps.
        tsv_loss = TimeStepVaryingLoss()
        return tsv_loss(y_pred, y_sequence, self._LOSS_SCHEDULE)

    # Use a DataLoader with automatic batching turned off (batch_size=None):
    # each step expects a single (x_sequence, y_sequence) pair, and the model
    # itself treats the sequence as a batch of one.
    # https://pytorch.org/docs/stable/data.html
    def training_step(self, train_item, batch_idx):
        sequence_loss = self._sequence_loss(train_item)
        # TODO: compare sequence_loss between the LSTM and an MLP.
        self.log('train_loss', sequence_loss, on_epoch=True, batch_size=1)
        # self.accuracy is intentionally not logged: it is never updated, so
        # the logged value would be meaningless.
        return sequence_loss

    # Same DataLoader requirement as training_step: one unbatched
    # (x_sequence, y_sequence) pair per step.
    # https://pytorch.org/docs/stable/data.html
    def validation_step(self, val_item, batch_idx):
        sequence_loss = self._sequence_loss(val_item)
        self.log('val_loss', sequence_loss, on_epoch=True, batch_size=1)

In [9]:
# dummy flights price data, parsed from a space-separated string into floats
all_data = list(map(float, "112. 118. 132. 129. 121. 135. 148. 148. 136. 119. 104. 118. 115. 126. 141. 135. 125. 149. 170. 170. 158. 133. 114. 140. 145. 150. 178. 163. 172. 178. 199. 199. 184. 162. 146. 166. 171. 180. 193. 181. 183. 218. 230. 242. 209. 191. 172. 194. 196. 196. 236. 235. 229. 243. 264. 272. 237. 211. 180. 201. 204. 188. 235. 227. 234. 264. 302. 293. 259. 229. 203. 229. 242. 233. 267. 269. 270. 315. 364. 347. 312. 274. 237. 278. 284. 277. 317. 313. 318. 374. 413. 405. 355. 306. 271. 306. 315. 301. 356. 348. 355. 422. 465. 467. 404. 347. 305. 336. 340. 318. 362. 348. 363. 435. 491. 505. 404. 359. 310. 337. 360. 342. 406. 396. 420. 472. 548. 559. 463. 407. 362. 405. 417. 391. 419. 461. 472. 535. 622. 606. 508. 461. 390. 432.".split(" ")))

In [10]:
# Length of the input/label windows and size of the held-out tail.
train_window = 12
test_data_size = 12

# The last `test_data_size` observations are held out for testing ...
test_data = all_data[-test_data_size:]
# ... and everything before them becomes a flat float tensor for training.
train_data = torch.FloatTensor(all_data[:-test_data_size]).view(-1)

def create_inout_sequences(input_data, tw):
    """Slice a sequence into sliding (input window, label window) pairs.

    For every start index ``i`` the input is ``input_data[i:i + tw]`` and the
    label is the ``tw`` items that immediately follow it. Only starts for
    which a full-length label exists are kept.

    Args:
        input_data: sliceable sequence (e.g. a list or a 1-D tensor).
        tw: window length used for both the input and the label.

    Returns:
        List of ``(input_window, label_window)`` tuples; empty when
        ``input_data`` holds fewer than ``2 * tw`` items.

    Raises:
        ValueError: if ``tw`` is smaller than 1.
    """
    if tw < 1:
        raise ValueError(f"tw must be a positive window length, got {tw}")
    # A pair needs tw inputs plus tw labels. The label length is tied to `tw`
    # (it was previously hard-coded to 12, which broke every other window size).
    n_pairs = len(input_data) - 2 * tw + 1
    return [
        (input_data[start:start + tw], input_data[start + tw:start + 2 * tw])
        for start in range(n_pairs)
    ]

# Build the sliding (input window, label window) pairs used for training.
train_inout_seq = create_inout_sequences(input_data=train_data, tw=train_window)




In [11]:
# Show the first five (input window, label window) pairs.
train_inout_seq[:5]

[(tensor([112., 118., 132., 129., 121., 135., 148., 148., 136., 119., 104., 118.]),
  tensor([115., 126., 141., 135., 125., 149., 170., 170., 158., 133., 114., 140.])),
 (tensor([118., 132., 129., 121., 135., 148., 148., 136., 119., 104., 118., 115.]),
  tensor([126., 141., 135., 125., 149., 170., 170., 158., 133., 114., 140., 145.])),
 (tensor([132., 129., 121., 135., 148., 148., 136., 119., 104., 118., 115., 126.]),
  tensor([141., 135., 125., 149., 170., 170., 158., 133., 114., 140., 145., 150.])),
 (tensor([129., 121., 135., 148., 148., 136., 119., 104., 118., 115., 126., 141.]),
  tensor([135., 125., 149., 170., 170., 158., 133., 114., 140., 145., 150., 178.])),
 (tensor([121., 135., 148., 148., 136., 119., 104., 118., 115., 126., 141., 135.]),
  tensor([125., 149., 170., 170., 158., 133., 114., 140., 145., 150., 178., 163.]))]

In [12]:
# batch_size=None turns off automatic batching, so the loader yields one
# (input window, label window) pair at a time.
torch_train_inout_seq = torch.utils.data.DataLoader(
    dataset=train_inout_seq,
    batch_size=None,
    num_workers=1,
)

In [13]:
# `gpus=None` meant "train on CPU"; the `gpus` argument was deprecated in
# Lightning 1.7 and removed in 2.0, so the CPU-only run is requested through
# `accelerator` instead.
trainer = pl.Trainer(
    accelerator="cpu",
    enable_progress_bar=False,
    precision=16,
    max_epochs=10,
)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [14]:
# Instantiate the Lightning wrapper with its default constructor arguments.
lit_model = LitWrappedLSTM()

In [15]:
# NOTE: the same loader is passed for training and validation, so the
# validation loss is measured on the training sequences.
trainer.fit(
    lit_model,
    train_dataloaders=torch_train_inout_seq,
    val_dataloaders=torch_train_inout_seq,
)


  | Name     | Type               | Params
------------------------------------------------
0 | model    | LSTM               | 42.4 K
1 | accuracy | MulticlassAccuracy | 0     
------------------------------------------------
42.4 K    Trainable params
0         Non-trainable params
42.4 K    Total params
0.170     Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=10` reached.
