In [1]:
import pytorch_lightning as ptl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
import polars as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from typing import Tuple
import torch.optim as optim
from model.peephole_lstm import PeepholeLSTM
from model.data_loader import load_data, prepare_data_loaders

In [2]:
data: pl.DataFrame = load_data("/home/pupperemeritus/isro_project/data/October2023.parquet")

In [4]:
data.shape

(2861526, 67)

In [3]:
data.head(20)

GPS_WN,GPS_TOW,SVID,Value of the RxState field of the ReceiverStatus SBF block,Azimuth,Elevation,Average Sig1 C/N0 over the last minute (dB-Hz),S4,Correction to total S4 on Sig1 (thermal noise component only) (dimensionless),"Phi01 on Sig1, 1-second phase sigma (radians)","Phi03 on Sig1, 3-second phase sigma (radians)","Phi10 on Sig1, 10-second phase sigma (radians)","Phi30 on Sig1, 30-second phase sigma (radians)","Phi60 on Sig1, 60-second phase sigma (radians)","AvgCCD on Sig1, average of code/carrier divergence (meters)","SigmaCCD on Sig1, standard deviation of code/carrier divergence (meters)",TEC at TOW - 45 seconds (TECU),dTEC from TOW - 60s to TOW - 45s (TECU),TEC at TOW - 30 seconds (TECU),dTEC from TOW - 45s to TOW - 30s (TECU),TEC at TOW - 15 seconds (TECU),dTEC from TOW - 30s to TOW - 15s (TECU),TEC at TOW (TECU),dTEC from TOW - 15s to TOW (TECU),Sig1 lock time (seconds),sbf2ismr version number,Lock time on the second frequency used for the TEC computation (seconds),Averaged C/N0 of second frequency used for the TEC computation (dB-Hz),SI Index on Sig1: (10*log10(Pmax)-10*log10(Pmin))/(10*log10(Pmax)+10*log10(Pmin)) (dimensionless),"SI Index on Sig1, numerator only: 10*log10(Pmax)-10*log10(Pmin) (dB)","p on Sig1, spectral slope of detrended phase in the 0.1 to 25Hz range (dimensionless)",Average Sig2 C/N0 over the last minute (dB-Hz),Total S4 on Sig2 (dimensionless),Correction to total S4 on Sig2 (thermal noise component only) (dimensionless),"Phi01 on Sig2, 1-second phase sigma (radians)","Phi03 on Sig2, 3-second phase sigma (radians)","Phi10 on Sig2, 10-second phase sigma (radians)","Phi30 on Sig2, 30-second phase sigma (radians)","Phi60 on Sig2, 60-second phase sigma (radians)","AvgCCD on Sig2, average of code/carrier divergence (meters)","SigmaCCD on Sig2, standard deviation of code/carrier divergence (meters)",Sig2 lock time (seconds),SI Index on Sig2 (dimensionless),"SI Index on Sig2, numerator only (dB)","p on Sig2, phase spectral slope in the 0.1 to 25Hz range (dimensionless)",Average Sig3 C/N0 over the last minute (dB-Hz),Total S4 on Sig3 (dimensionless),Correction to total S4 on Sig3 (thermal noise component only) (dimensionless),"Phi01 on Sig3, 1-second phase sigma (radians)","Phi03 on Sig3, 3-second phase sigma (radians)","Phi10 on Sig3, 10-second phase sigma (radians)","Phi30 on Sig3, 30-second phase sigma (radians)","Phi60 on Sig3, 60-second phase sigma (radians)","AvgCCD on Sig3, average of code/carrier divergence (meters)","SigmaCCD on Sig3, standard deviation of code/carrier divergence (meters)",Sig3 lock time (seconds),SI Index on Sig3 (dimensionless),"SI Index on Sig3, numerator only (dB)","p on Sig3, phase spectral slope in the 0.1 to 25Hz range (dimensionless)","T on Sig1, phase power spectral density at 1 Hz (rad^2/Hz)","T on Sig2, phase power spectral density at 1 Hz (rad^2/Hz)",Vertical Scintillation Amplitude,Vertical Scintillation Phase,"T on Sig3, phase power spectral density at 1 Hz (rad^2/Hz)",Latitude,Longitude,IST_Time
i64,i64,i64,i64,i64,i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,i64,i64,i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,datetime[μs]
2282,60,5,630,17,67,49.4,0.027,0.034,0.017,0.031,0.04,0.04,0.041,-14.596,0.004,12.61,0.01,12.572,0.007,12.591,0.008,12.572,0.016,13911,781,7834,57.4,0.007,0.782,1.8,48.7,0.035,0.037,0.017,0.03,0.039,0.039,0.04,-25.402,0.007,7834,0.01,0.93,1.66,,,,,,,,,,,,,,,0.000027,0.00003,0.020354,0.03827,,18.777651,78.758218,2023-10-01 05:31:00
2282,60,11,630,47,20,36.2,0.124,0.156,0.039,0.041,0.043,0.043,0.043,-30.136,0.018,19.332,0.089,19.37,0.118,19.77,0.109,19.989,0.078,22104,781,22100,22.1,0.039,3.334,1.31,39.5,0.109,0.106,0.036,0.042,0.054,0.056,0.058,-48.513,0.05,256,0.038,2.975,1.51,42.2,0.079,0.077,0.027,0.037,0.056,0.06,0.069,-53.368,0.066,22143,0.026,2.3,1.45,0.000137,0.000148,0.109271,0.041292,0.000081,22.710547,84.609833,2023-10-01 05:31:00
2282,60,12,630,195,53,48.9,0.031,0.036,0.015,0.02,0.022,0.023,0.023,-5.513,0.011,14.542,0.023,14.637,0.02,14.656,0.029,14.732,0.024,15353,781,7834,54.9,0.008,0.91,1.83,46.5,0.047,0.047,0.017,0.023,0.028,0.029,0.029,-9.81,0.018,7834,0.014,1.32,1.76,,,,,,,,,,,,,,,0.00003,0.000046,0.007094,0.016135,,14.927063,77.62753,2023-10-01 05:31:00
2282,60,13,630,124,28,41.3,0.086,0.086,0.035,0.054,0.081,0.087,0.098,-10.845,0.002,24.417,-0.006,24.551,0.004,24.646,0.023,24.779,-0.004,12400,781,6250,27.9,0.026,2.475,1.77,,,,,,,,,,,,,,,,,,,,,,,,,,,,,0.000182,,0.014132,0.062675,,14.027978,83.352908,2023-10-01 05:31:00
2282,60,15,630,169,32,43.7,0.075,0.065,0.027,0.034,0.044,0.046,0.047,-7.634,0.011,20.718,-0.017,20.603,-0.043,20.527,-0.026,20.622,-0.023,6128,781,6102,45.5,0.021,2.063,1.53,41.9,0.083,0.08,0.028,0.033,0.039,0.04,0.04,-12.534,0.019,6121,0.027,2.312,1.61,,,,,,,,,,,,,,,0.000086,0.000102,0.031299,0.03677,,12.342354,79.312619,2023-10-01 05:31:00
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
2282,60,47,630,230,30,44.2,0.064,0.062,0.025,0.035,0.055,0.056,0.056,-4.503,0.014,,,,,,,,,11636,781,,,0.019,1.704,1.73,,,,,,,,,,,,,,,,,,,,,,,,,,,,,0.000085,,0.062645,0.055698,,13.806814,73.969906,2023-10-01 05:31:00
2282,60,48,630,291,24,47.7,0.059,0.041,0.031,0.047,0.067,0.071,0.078,-1.968,0.005,56.012,0.017,56.119,0.031,56.363,-0.012,56.334,0.016,4878,781,4875,46.2,0.016,1.564,2.14,46.2,0.087,0.049,0.024,0.033,0.043,0.044,0.045,-3.296,0.012,4875,0.024,2.211,1.85,,,,,,,,,,,,,,,0.000132,0.000085,0.047788,0.074635,,19.718511,71.541273,2023-10-01 05:31:00
2282,60,57,630,30,28,38.0,0.124,0.126,0.047,0.061,0.085,0.087,0.09,-24.469,0.016,17.959,-0.057,17.91,0.019,18.36,0.023,18.213,0.063,16684,781,7834,38.8,0.045,3.387,1.7,38.8,0.122,0.115,0.042,0.048,0.063,0.063,0.064,-36.836,0.015,7834,0.041,3.2,1.61,,,,,,,,,,,,,,,0.000313,0.000221,0.021692,0.057559,,22.476348,81.500818,2023-10-01 05:31:00
2282,60,58,630,327,19,44.6,0.113,0.059,0.038,0.082,0.114,0.116,0.118,-28.491,0.01,30.414,-0.032,30.16,-0.024,30.179,0.003,30.082,-0.037,9333,781,7834,43.3,0.031,2.72,1.78,43.3,0.136,0.069,0.035,0.073,0.098,0.1,0.103,-40.75,0.019,7834,0.035,3.044,1.68,,,,,,,,,,,,,,,0.000098,0.000099,0.013305,0.06967,,24.267627,73.375128,2023-10-01 05:31:00


In [6]:
data.describe(interpolation="linear")

statistic,GPS_WN,GPS_TOW,SVID,Value of the RxState field of the ReceiverStatus SBF block,Azimuth,Elevation,Average Sig1 C/N0 over the last minute (dB-Hz),S4,Correction to total S4 on Sig1 (thermal noise component only) (dimensionless),"Phi01 on Sig1, 1-second phase sigma (radians)","Phi03 on Sig1, 3-second phase sigma (radians)","Phi10 on Sig1, 10-second phase sigma (radians)","Phi30 on Sig1, 30-second phase sigma (radians)","Phi60 on Sig1, 60-second phase sigma (radians)","AvgCCD on Sig1, average of code/carrier divergence (meters)","SigmaCCD on Sig1, standard deviation of code/carrier divergence (meters)",TEC at TOW - 45 seconds (TECU),dTEC from TOW - 60s to TOW - 45s (TECU),TEC at TOW - 30 seconds (TECU),dTEC from TOW - 45s to TOW - 30s (TECU),TEC at TOW - 15 seconds (TECU),dTEC from TOW - 30s to TOW - 15s (TECU),TEC at TOW (TECU),dTEC from TOW - 15s to TOW (TECU),Sig1 lock time (seconds),sbf2ismr version number,Lock time on the second frequency used for the TEC computation (seconds),Averaged C/N0 of second frequency used for the TEC computation (dB-Hz),SI Index on Sig1: (10*log10(Pmax)-10*log10(Pmin))/(10*log10(Pmax)+10*log10(Pmin)) (dimensionless),"SI Index on Sig1, numerator only: 10*log10(Pmax)-10*log10(Pmin) (dB)","p on Sig1, spectral slope of detrended phase in the 0.1 to 25Hz range (dimensionless)",Average Sig2 C/N0 over the last minute (dB-Hz),Total S4 on Sig2 (dimensionless),Correction to total S4 on Sig2 (thermal noise component only) (dimensionless),"Phi01 on Sig2, 1-second phase sigma (radians)","Phi03 on Sig2, 3-second phase sigma (radians)","Phi10 on Sig2, 10-second phase sigma (radians)","Phi30 on Sig2, 30-second phase sigma (radians)","Phi60 on Sig2, 60-second phase sigma (radians)","AvgCCD on Sig2, average of code/carrier divergence (meters)","SigmaCCD on Sig2, standard deviation of code/carrier divergence (meters)",Sig2 lock time (seconds),SI Index on Sig2 (dimensionless),"SI Index on Sig2, numerator only (dB)","p on Sig2, phase spectral slope in the 0.1 to 25Hz range (dimensionless)",Average Sig3 C/N0 over the last minute (dB-Hz),Total S4 on Sig3 (dimensionless),Correction to total S4 on Sig3 (thermal noise component only) (dimensionless),"Phi01 on Sig3, 1-second phase sigma (radians)","Phi03 on Sig3, 3-second phase sigma (radians)","Phi10 on Sig3, 10-second phase sigma (radians)","Phi30 on Sig3, 30-second phase sigma (radians)","Phi60 on Sig3, 60-second phase sigma (radians)","AvgCCD on Sig3, average of code/carrier divergence (meters)","SigmaCCD on Sig3, standard deviation of code/carrier divergence (meters)",Sig3 lock time (seconds),SI Index on Sig3 (dimensionless),"SI Index on Sig3, numerator only (dB)","p on Sig3, phase spectral slope in the 0.1 to 25Hz range (dimensionless)","T on Sig1, phase power spectral density at 1 Hz (rad^2/Hz)","T on Sig2, phase power spectral density at 1 Hz (rad^2/Hz)",Vertical Scintillation Amplitude,Vertical Scintillation Phase,"T on Sig3, phase power spectral density at 1 Hz (rad^2/Hz)",Latitude,Longitude,IST_Time
str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,str
"""count""",2861526.0,2861526.0,2861526.0,2861526.0,2861526.0,2861526.0,2861526.0,2861526.0,2861526.0,2861193.0,2860840.0,2859392.0,2855880.0,2851929.0,2861526.0,2861526.0,2048120.0,2048120.0,2047947.0,2035740.0,2047866.0,2035561.0,2064552.0,2047866.0,2856186.0,2861526.0,2100334.0,2141623.0,2861525.0,2861525.0,2848677.0,2058520.0,2057873.0,2057873.0,2057133.0,2056179.0,2051080.0,2035820.0,2019739.0,2058519.0,2058519.0,2018553.0,2058772.0,2058772.0,1992610.0,1854944.0,1854909.0,1854909.0,1854717.0,1854553.0,1854001.0,1852861.0,1851745.0,1854943.0,1854943.0,1854101.0,1854951.0,1854951.0,1850672.0,2848677.0,1992610.0,2861526.0,2861526.0,1850672.0,2861526.0,2861526.0,"""2861526"""
"""null_count""",0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,333.0,686.0,2134.0,5646.0,9597.0,0.0,0.0,813406.0,813406.0,813579.0,825786.0,813660.0,825965.0,796974.0,813660.0,5340.0,0.0,761192.0,719903.0,1.0,1.0,12849.0,803006.0,803653.0,803653.0,804393.0,805347.0,810446.0,825706.0,841787.0,803007.0,803007.0,842973.0,802754.0,802754.0,868916.0,1006582.0,1006617.0,1006617.0,1006809.0,1006973.0,1007525.0,1008665.0,1009781.0,1006583.0,1006583.0,1007425.0,1006575.0,1006575.0,1010854.0,12849.0,868916.0,0.0,0.0,1010854.0,0.0,0.0,"""0"""
"""mean""",2283.74349,285852.412119,115.527053,630.0,156.216331,35.561843,42.27588,0.126852,0.098654,0.04335,0.059783,0.073823,0.075742,0.07651,-4.692243,0.863711,49.573304,-0.000101,49.591379,0.001139,49.638422,0.002616,48.233404,-0.001537,18365.144978,781.0,11856.566998,41.623682,0.044002,3.627524,1.632307,42.892343,0.121018,0.0871,0.03971,0.057793,0.073562,0.074555,0.074335,-5.1725,0.095441,12271.251307,0.041567,3.321089,1.646262,45.238549,0.094613,0.065089,0.028103,0.037149,0.044331,0.045203,0.045497,-3.79563,0.072474,14497.969627,0.030759,2.623842,1.614762,0.000966,0.00026,,,0.000173,16.398506,79.486235,"""2023-10-16 17:48:35.382715"""
"""std""",1.293138,175110.130981,65.998176,0.0,91.849103,20.825982,6.01593,0.114198,0.075153,0.036163,0.065074,0.104788,0.129484,0.154903,28.169717,8.195888,184.222366,4.071804,184.129496,4.800821,184.104346,4.026672,184.924359,4.683901,19999.680397,0.0,17563.796537,8.099016,0.052642,3.548834,0.226544,5.34181,0.124721,0.057836,0.035357,0.068162,0.098707,0.101601,0.102725,37.11309,0.199761,17853.309359,0.055942,3.7644,0.220986,5.075013,0.115144,0.04097,0.030481,0.057669,0.079861,0.083302,0.084809,40.516601,0.12701,18905.652884,0.050432,3.696286,0.216915,0.068307,0.001292,,,0.001194,5.093402,5.607641,
"""min""",2282.0,0.0,2.0,630.0,0.0,0.0,22.3,0.015,0.017,0.008,0.01,0.012,0.012,0.012,-285.273,0.0,-1169.424,-1076.654,-1168.849,-1038.882,-1169.866,-1055.651,-1169.719,-999.304,0.0,781.0,0.0,1.3,0.003,0.35,0.0,22.5,0.017,0.018,0.007,0.009,0.01,0.01,0.01,-261.431,0.0,0.0,-0.039,-2.67,1.11,22.5,0.016,0.015,0.005,0.005,0.007,0.009,0.009,-174.708,0.0,0.0,-0.024,-1.5,1.07,8e-06,6e-06,0.000834,0.007157,3e-06,-2.558444,57.362963,"""2023-10-01 05:31:00"""
"""25%""",2283.0,133920.0,53.0,630.0,93.0,18.0,37.5,0.051,0.044,0.022,0.027,0.031,0.031,0.031,-17.611,0.015,25.24375,-0.074,25.251,-0.074,25.27,-0.073,24.478,-0.074,3310.0,781.0,733.0,37.8,0.014,1.421,1.47,38.7,0.05,0.044,0.021,0.026,0.029,0.029,0.029,-23.6435,0.024,760.0,0.014,1.358,1.49,41.4,0.037,0.034,0.015,0.019,0.021,0.021,0.021,-26.653,0.02,1586.0,0.01,1.025,1.47,5.5e-05,4.9e-05,0.020752,0.025202,2.9e-05,15.108257,76.729702,"""2023-10-08 23:59:00"""
"""50%""",2284.0,270240.0,132.0,630.0,142.0,32.0,42.9,0.088,0.072,0.033,0.044,0.05,0.051,0.051,-2.704,0.032,73.206,0.0,73.209,0.0,73.223,0.0,72.536,0.0,10911.0,781.0,4223.0,42.8,0.026,2.435,1.6,43.3,0.086,0.068,0.029,0.038,0.043,0.043,0.043,-0.628,0.051,4448.0,0.027,2.377,1.6,45.5,0.059,0.053,0.022,0.026,0.029,0.029,0.029,-2.351,0.042,6502.0,0.017,1.622,1.58,0.000119,9.2e-05,0.04543,0.04209,5.7e-05,16.132205,79.520606,"""2023-10-16 18:15:00"""
"""75%""",2285.0,437220.0,165.0,630.0,227.0,52.0,47.2,0.169,0.134,0.057,0.072,0.083,0.084,0.085,9.019,0.065,128.82025,0.081,128.821,0.081,128.836,0.08,128.351,0.081,24811.0,781.0,14240.0,47.1,0.057,4.719,1.76,47.1,0.153,0.117,0.048,0.062,0.076,0.077,0.076,9.467,0.108,15043.0,0.049,4.01,1.76,49.3,0.106,0.085,0.031,0.037,0.041,0.042,0.042,18.934,0.083,18530.0,0.033,2.86,1.7,0.000303,0.000237,0.090354,0.070875,0.000111,19.57806,82.836522,"""2023-10-24 11:30:00"""
"""max""",2286.0,604740.0,242.0,630.0,359.0,90.0,55.5,8.366,0.827,10.115,25.711,56.268,71.637,89.736,118.831,155.151,519.132,902.937,463.384,1357.556,513.899,1144.207,454.718,1320.804,65534.0,781.0,65534.0,62.1,1.0,49.295,3.75,54.8,17.488,0.766,1.743,4.72,7.782,8.522,11.234,183.999,128.308,65534.0,1.0,53.661,3.8,56.2,9.896,0.679,1.105,2.627,4.565,5.422,7.398,174.925,8.677,65534.0,1.0,54.486,3.81,9.92952,0.252415,7.762037,70.20322,0.161901,35.170931,98.994946,"""2023-11-01 05:30:00"""


In [5]:
class UnidirectionalLSTM(ptl.LightningModule):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        learning_rate: float = 0.001,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False,
            dropout=dropout,
        )
        self.fc = nn.Linear(hidden_size, 1)
        self.learning_rate = learning_rate

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Replace NaNs with zeros and create a mask
        if mask is None:
            mask = ~torch.isnan(x)
        x = torch.nan_to_num(x, nan=0.0)

        # Pack the sequence to handle variable-length inputs
        lengths = mask.sum(dim=1).cpu()
        packed_x = nn.utils.rnn.pack_padded_sequence(
            x, lengths, batch_first=True, enforce_sorted=False
        )

        # Process with LSTM
        packed_lstm_out, _ = self.lstm(packed_x)

        # Unpack the sequence
        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(
            packed_lstm_out, batch_first=True
        )

        # Get the last valid output for each sequence
        idx = (lengths - 1).view(-1, 1).expand(-1, lstm_out.size(2))
        last_valid = lstm_out.gather(1, idx.unsqueeze(1)).squeeze(1)

        out = self.fc(last_valid)
        return out

    def _step(
        self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        x, y, mask = batch
        y_hat = self(x, mask)

        # Use masked MSE loss
        loss = nn.MSELoss(reduction="none")(y_hat, y)
        mask = mask[:, -1].unsqueeze(1)  # Use mask for the last time step
        loss = (loss * mask).sum() / mask.sum()

        return loss

    def training_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        loss = self._step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int
    ) -> None:
        loss = self._step(batch)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)

    def test_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int
    ) -> None:
        loss = self._step(batch)
        loss = torch.log1p(loss)
        self.log("test_loss", loss, on_step=False, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.1, patience=10
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }

In [6]:
batch_size = 128
file_path = "/home/pupperemeritus/isro_project/data/October2023.parquet"
train_loader, val_loader, test_loader, target_scaler = prepare_data_loaders(
    file_path,
    batch_size=batch_size,
    sequence_length=60,
    prediction_horizon=1,
    missing_data="interpolate",
    max_gap=1,
    test_size=0.22,
    val_size=0.11,
    stride=30,
)

In [7]:
torch.clear_autocast_cache()
torch.cuda.memory.empty_cache()
input_size = train_loader.dataset.features.shape[1] - 3 # Number of features
print(input_size)
hidden_size = 64
num_layers = 8
output_size = 1  # For regression
logger = TensorBoardLogger("tb_logs", name="gurunet_model")
epochs=100
checkpoint_callback = ModelCheckpoint(
    dirpath=f"checkpoints/version_{logger.version}",
    filename="gurunet-{epoch:02d}-{val_loss:.5f}",
    save_top_k=3,
    monitor="val_loss",
    mode="min",
    verbose=True,
)

early_stop_callback = EarlyStopping(monitor="val_loss", patience=15, mode="min")

lr_monitor = LearningRateMonitor(logging_interval="epoch", log_momentum=True)
model = UnidirectionalLSTM(input_size, hidden_size, num_layers, output_size)
torch.set_float32_matmul_precision("high")
trainer = ptl.Trainer(
    max_epochs=epochs,
    accelerator="gpu",
    devices=1,
    callbacks=[
        checkpoint_callback,
        early_stop_callback,
        lr_monitor,
    ],
    logger=logger,
    precision="16-mixed",
    enable_progress_bar=True,
    enable_checkpointing=True,
    accumulate_grad_batches=4,
    profiler="simple",
    min_epochs=50,
)
trainer.fit(model, train_loader, val_loader)

# Test the model
trainer.test(model, test_loader)

Using 16bit Automatic Mixed Precision (AMP)


62


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type   | Params | Mode 
----------------------------------------
0 | lstm | LSTM   | 265 K  | train
1 | fc   | Linear | 65     | train
----------------------------------------
265 K     Trainable params
0         Non-trainable params
265 K     Total params
1.063     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

RuntimeError: input.size(-1) must be equal to input_size. Expected 62, got 65