In [None]:
import torch
import torch.nn as nn

class TLSTMCell(nn.Module):
    """Single LSTM cell whose previous cell state is scaled by a learned time gate.

    The input, forget, output and candidate gates are computed from the
    concatenation ``[x, h_prev]``. Before the usual cell update, ``c_prev`` is
    multiplied element-wise by ``sigmoid(time_gate(time_delta))``. The whole
    previous cell state is gated this way; it is not split into separate
    short-/long-term parts. Because the gate is a sigmoid, the factor lies in
    (0, 1), so even a zero delta scales ``c_prev`` by ``sigmoid(bias)``.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Each gate is its own affine projection of [input, hidden].
        self.i_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.f_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.c_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.o_gate = nn.Linear(input_size + hidden_size, hidden_size)
        # Maps the scalar time delta to a per-unit scaling of c_prev.
        self.time_gate = nn.Linear(1, hidden_size)

    def forward(self, x, h_prev, c_prev, time_delta):
        """Advance the cell by one step.

        Args:
            x: Input at this step, shape ``(batch, input_size)``.
            h_prev: Previous hidden state, shape ``(batch, hidden_size)``.
            c_prev: Previous cell state, shape ``(batch, hidden_size)``.
            time_delta: Per-sample time delta of shape ``(batch,)`` or
                ``(batch, 1)``, or a 0-d tensor shared by the whole batch.
                Any numeric dtype is accepted; it is cast to the time gate's
                dtype and device.

        Returns:
            Tuple ``(h, c)``: the new hidden and cell states, each of shape
            ``(batch, hidden_size)``.
        """
        combined = torch.cat((x, h_prev), dim=1)
        i = torch.sigmoid(self.i_gate(combined))
        f = torch.sigmoid(self.f_gate(combined))
        o = torch.sigmoid(self.o_gate(combined))
        c_tilde = torch.tanh(self.c_gate(combined))

        # nn.Linear(1, hidden) needs a floating dtype matching its weights
        # (integer or differently-typed deltas would raise) and a trailing
        # feature axis of size 1. A (batch, 1) delta must NOT get another axis,
        # or the product with c_prev would broadcast to (batch, batch, hidden).
        weight = self.time_gate.weight
        time_delta = time_delta.to(device=weight.device, dtype=weight.dtype)
        if time_delta.dim() < 2:
            time_delta = time_delta.reshape(-1, 1)
        time_factor = torch.sigmoid(self.time_gate(time_delta))

        c_prev_decayed = c_prev * time_factor
        c = f * c_prev_decayed + i * c_tilde
        h = o * torch.tanh(c)
        return h, c

class BiTLSTMModel(nn.Module):
    """Bidirectional LSTM encoder followed by a time-aware recurrent cell and a scalar head.

    The input sequence is encoded by a (possibly multi-layer) bidirectional
    ``nn.LSTM``. Its per-step outputs are then fed, one step at a time, through
    a ``TLSTMCell`` whose memory is modulated by ``time_deltas``. The final
    hidden state is batch-normalised, passed through dropout and projected to
    one value per sample.
    """

    def __init__(self, input_size, lstm_hidden_size, lstm_num_layers, dropout_prob=0.3):
        super().__init__()
        self.bilstm = nn.LSTM(
            input_size=input_size,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_num_layers,
            batch_first=True,
            bidirectional=True,
        )
        # A bidirectional LSTM concatenates both directions: 2 * hidden features.
        self.tlstm = TLSTMCell(lstm_hidden_size * 2, lstm_hidden_size)
        self.bn = nn.BatchNorm1d(lstm_hidden_size)
        self.fc = nn.Linear(lstm_hidden_size, 1)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x, time_deltas):
        """Run the model on a batch of sequences.

        Args:
            x: Input of shape ``(batch, seq_len, input_size)`` (the encoder is
                ``batch_first``).
            time_deltas: Tensor indexed as ``time_deltas[:, t]`` for
                ``t < seq_len - 1``. Both ``(batch, seq_len - 1)`` and
                ``(batch, seq_len)`` work; the last column of the latter is
                never read. It is cast to the encoder output's dtype/device.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        bilstm_out, _ = self.bilstm(x)
        batch_size, seq_len = x.size(0), x.size(1)

        # Allocate directly with the encoder output's dtype/device. The previous
        # float32-only allocation broke the model after .double()/.half().
        h = bilstm_out.new_zeros(batch_size, self.tlstm.hidden_size)
        c = bilstm_out.new_zeros(batch_size, self.tlstm.hidden_size)
        zero_delta = bilstm_out.new_zeros(batch_size)
        time_deltas = time_deltas.to(device=bilstm_out.device, dtype=bilstm_out.dtype)

        for t in range(seq_len):
            # NOTE(review): c_prev at step t is scaled using time_deltas[:, t],
            # and the last step always gets 0. If time_deltas[:, t] means the gap
            # between steps t and t+1, the decay applied at step t is shifted by
            # one position. Confirm the intended convention with the data pipeline.
            delta_t = time_deltas[:, t] if t < seq_len - 1 else zero_delta
            h, c = self.tlstm(bilstm_out[:, t, :], h, c, delta_t)

        # BatchNorm1d cannot compute batch statistics from a single sample in
        # training mode, so skip it only there. In eval mode it uses running
        # statistics and must run for every batch size; otherwise a lone
        # sample would be scored on a different scale than the same sample
        # inside a larger batch.
        if batch_size > 1 or not self.training:
            h = self.bn(h)
        out = self.dropout(h)
        return self.fc(out)
