In [1]:
import torch # to create tensors to store the raw data, weights and biases 
import torch.nn as nn # to make the weights and biases learnable (part of the network)
import torch.nn.functional as F # to apply activation functions 
from torch.optim import Adam  

import lightning as L # to train the model 
from torch.utils.data import DataLoader, TensorDataset # to load the data 
from lightning.pytorch.tuner.tuning import Tuner      
from lightning.pytorch.callbacks import ModelCheckpoint #Lightning is trying to delete the previous checkpoint (because the default ModelCheckpoint is set to save_top_k=1).
                                                        # On Windows a file gets locked as soon as any program (Explorer preview, antivirus, TensorBoard, VS Code, …) opens it, and Windows then blocks the delete call ⇒ PermissionError WinError 32.


import matplotlib.pyplot as plt # graphs 
import seaborn as sns # graphs


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from typing import override


class LSTMbyHand(L.LightningModule):
    """A single LSTM unit written out by hand with scalar weights and biases.

    Every weight and bias is an individual ``nn.Parameter`` so the arithmetic
    of each gate is visible. The module unrolls the unit over an input
    sequence and returns the final short-term memory as the prediction.
    """

    def __init__(self):
        """Create the learnable weights (drawn from N(0, 1)) and biases (0)."""
        super().__init__()
        mean = torch.tensor(0.0)
        std = torch.tensor(1.0)

        # Naming: the "1" weights multiply the short-term memory, the "2"
        # weights multiply the current input, and "b*" is the gate's bias.

        # Percentage of the long-term memory to keep (forget gate).
        self.wlr1 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.wlr2 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.blr1 = nn.Parameter(torch.tensor(0.0), requires_grad=True)

        # Percentage of the candidate long-term memory to add (input gate).
        self.wpr1 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.wpr2 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.bpr1 = nn.Parameter(torch.tensor(0.0), requires_grad=True)

        # Candidate long-term memory itself.
        self.wp1 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.wp2 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.bp1 = nn.Parameter(torch.tensor(0.0), requires_grad=True)

        # Percentage of the new long-term memory to expose (output gate).
        self.wo1 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.wo2 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.bo1 = nn.Parameter(torch.tensor(0.0), requires_grad=True)

    def lstm_unit(self, input_value, long_memory, short_memory):
        """Run one LSTM step.

        Args:
            input_value: The input for the current time step.
            long_memory: Long-term memory carried in from the previous step.
            short_memory: Short-term memory carried in from the previous step.

        Returns:
            A two-element list ``[updated_long_memory, updated_short_memory]``.
        """
        # Stage 1: how much of the existing long-term memory survives.
        long_remember_percent = torch.sigmoid(
            short_memory * self.wlr1 + input_value * self.wlr2 + self.blr1
        )

        # Stage 2: build a candidate memory and decide how much of it to add.
        potential_long_memory_remember_percent = torch.sigmoid(
            short_memory * self.wpr1 + input_value * self.wpr2 + self.bpr1
        )
        potential_long_memory = torch.tanh(
            short_memory * self.wp1 + input_value * self.wp2 + self.bp1
        )
        updated_long_memory = (long_memory * long_remember_percent) + (
            potential_long_memory * potential_long_memory_remember_percent
        )

        # Stage 3: derive the new short-term memory from the new long-term one.
        output_percent = torch.sigmoid(
            short_memory * self.wo1 + input_value * self.wo2 + self.bo1
        )
        updated_short_memory = torch.tanh(updated_long_memory) * output_percent

        return [updated_long_memory, updated_short_memory]

    def forward(self, input):
        """Unroll the LSTM unit over every element of ``input``.

        Both memories start at 0 and are threaded through one ``lstm_unit``
        call per time step, in order.

        Args:
            input: An indexable, sized sequence of per-step input values
                (for example a 1-D tensor). Any positive length is accepted.

        Returns:
            The short-term memory after the last time step.

        Raises:
            IndexError: If ``input`` is empty.
        """
        if len(input) == 0:
            raise IndexError("input sequence must contain at least one time step")

        long_memory = 0
        short_memory = 0

        # One unit application per time step; the memories produced by one
        # step feed the next.
        for step_value in input:
            long_memory, short_memory = self.lstm_unit(
                step_value, long_memory, short_memory
            )

        return short_memory

    def configure_optimizers(self):
        """Return an Adam optimizer (default settings) over all parameters."""
        return Adam(self.parameters())

    def training_step(self, batch, batch_idx):
        """Compute and log the squared-error loss for one batch.

        Only the first sequence in the batch (``input_i[0]``) is evaluated,
        and the label is compared against 0 as a single truth value, so this
        step is written for a batch size of 1.

        Args:
            batch: An ``(inputs, labels)`` pair.
            batch_idx: Index of the batch; unused here.

        Returns:
            The squared difference between the prediction and the label.
        """
        input_i, label_i = batch
        # Call the module rather than ``forward`` directly so that any
        # registered module hooks also run.
        output_i = self(input_i[0])

        loss = (output_i - label_i) ** 2

        # Record the loss through Lightning's logging.
        self.log("train_loss", loss)

        # Track predictions separately by label: "out_0" when the label is 0,
        # "out_1" otherwise.
        if label_i == 0:
            self.log("out_0", output_i)
        else:
            self.log("out_1", output_i)

        return loss


