# Dependencies and Imports


In [None]:
# !pip3 install torch pytorch-lightning

In [15]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F # ReLU, Softmax
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader, TensorDataset #, IterableDataset, TensorDataset
import torch.optim as optim
from sklearn.preprocessing import StandardScaler
import pytorch_lightning as pl

from functools import reduce
import numpy as np
import pandas as pd
import datetime as dt
from dateutil.relativedelta import relativedelta
from datetime import date, timedelta

from NowcastingPipelineM import NowcastingPH_M
import dynamicfactoranalysis as dfa

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # torch.backends.cudnn.deterministic = True                     # Only cuDNN convolution algorithms
    torch.use_deterministic_algorithms(True)                        # All torch and cuDNN algorithms when available. RunTime error if not available.
    torch.backends.cudnn.benchmark = False                          # Uses the same algorthim. May lose out on performance.
    # torch.utils.deterministic.fill_uninitialized_memory = True    # For torch.empty() or torch.Tensor.resize()
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'                 # For RNN/LSTM
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(torch.cuda.current_device())
print(torch.cuda.device(0))
print(torch.cuda.get_device_name(0))

cuda
2.6.0+cu124
True
1
0
<torch.cuda.device object at 0x796541d1c2e0>
Quadro RTX 6000


# Dataset and Dataloader preparation

In [2]:
class NowcastingLSTM_MQ(NowcastingPH_M):
    def set_classname(self, **kwargs):
        self.prefix = f'LSTM{self.kwargs.get("lag_order")} x ' + ('DFM_Opt' if self.kwargs.get("optimize_order") else f'DFM{self.kwargs.get("DFM_order")}') if self.kwargs.get("extend") else f'LSTM{self.kwargs.get("lag_order")}'
    def load_tweets(self, vintage, window, kmpair, freq='M', extend=False, **kwargs):
        vintage = pd.to_datetime(vintage)
        tweets = pd.read_csv('data/PH_Tweets_v3.csv')
        tweets['date'] = pd.to_datetime(tweets['date']) + pd.offsets.MonthEnd(0)
        tweets = tweets.set_index('date')

        if len(kmpair) == 0:
            kmpair = {keyword: list(tweets.columns.drop('keyword')) for keyword in tweets['keyword'].unique()}
        data = [tweets[tweets['keyword'] == keyword][kmpair[keyword]].add_suffix(f'_{keyword}') for keyword in kmpair.keys()]
        tweets = reduce(lambda left, right: pd.merge(left, right, on='date', how='outer', sort=True), data)
        tweets = tweets.loc[dt.datetime(2010,1,1) : pd.to_datetime(vintage), :]
        # tweets = tweets.loc[pd.to_datetime(vintage)  - relativedelta(months =  (pd.to_datetime(vintage).month - 1)%3 + window) : pd.to_datetime(vintage), :]
        # tweets = super().load_tweets(vintage, freq='M', **kwargs)
        # DFM_order = self.kwargs.get('DFM_order')                                             ### temporary measure to solve stationarity error
        # kwargs['DFM_order'] = (1, DFM_order[1], DFM_order[2], DFM_order[3])                   ### temporary measure to solve stationarity error
        tweets = self.extend_data(tweets, vintage, **kwargs) if extend else tweets
        tweets.index = pd.PeriodIndex(tweets.index, freq=freq)
        return tweets
    def extend_data(self, df, vintage, DFM_order, optimize_order=False, **kwargs):
        ### Instead of extending until year end, just extend until current vintage
        factor_order, error_order, k_factors, factor_lag = DFM_order
        # drop row if not enough non-missing (max safety)
        df = df.dropna(thresh = k_factors * (1 + factor_lag))

        if optimize_order:
            model = dfa.DynamicFactorModelOptimizer(
                endog=df, k_factors_max=k_factors, factor_lag_max=factor_lag, factor_order_max=factor_order, 
                error_order_max=error_order, verbose=True,**kwargs).fit(**kwargs)
        else:
            model = dfa.DynamicFactorModel(
                endog=df, k_factors=k_factors, factor_lag=factor_lag, factor_order=factor_order, 
                error_order=error_order, **kwargs)
        results = model.fit(disp=False, maxiter=10, method='powell', ftol=1e-3, **kwargs)
        # results = model.fit(disp=False, maxiter=1000, method='powell', ftol=1e-5, **kwargs)
        
        df_extended = pd.DataFrame()
        for col in df.columns:
            col_extended = pd.concat([df[[col]].dropna(), 
                                    results.predict(start=df[col].dropna().index[-1], end=vintage)[[col]].iloc[1:]])
            df_extended = pd.concat([df_extended, col_extended], axis=1)
        df_extended.index.name = df.index.name

        return df_extended


In [20]:
target = 'GDP'
kmpair = {'PE': ['CRVADER_BVN','CR_BxP_0'],'PU+': ['CRVADER_BVN','CR_BxP_0']} # kmpair = {'PE':['CR_B0']}
window = 1000
extend = False
DFM_order = (1,0,1,0)
model = NowcastingLSTM_MQ(extend = extend, DFM_order=DFM_order, optimize_order = False, kmpair=kmpair, target=target)
data, target_scaler, econ_scaler, tweets_scaler = model.load_data(vintage=pd.to_datetime('2017-03-31'),window=window, kmpair=kmpair, with_econ=True, with_tweets=False, target_release_lag=True,scaled=True)
data = data.loc[pd.to_datetime('2010-01-31'):,:].dropna()#.reset_index()
# data[:] = StandardScaler().fit_transform(data)
# model.load_target(vintage=pd.to_datetime('2017-01-31'), target='GDP', growth=True, quarterly=True, freq='M', target_release_lag=False)
display(data)

Unnamed: 0_level_0,target,ECN.pseiclose_YoY,ECN.tbill_usd90,ECN.phpusd_YoY,ECN.tdr_php360,ECN.govtexpt_YoY,ECN.m1_YoY,ECN.IPI_YoY,ECN.imports_YoY,ECN.exports_YoY
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
2010-01,1.357160,2.546622,-0.321012,-0.573727,1.246838,0.367732,0.475311,2.979768,1.602402,2.201924
2010-02,1.357160,2.590565,0.039815,-0.610389,1.325077,-0.503924,1.288223,2.606936,1.288057,2.202541
2010-03,1.357160,2.395650,0.256312,-1.218834,1.340969,1.809788,2.249231,1.844254,2.177334,2.291599
2010-04,1.243454,2.235894,0.256312,-1.601437,1.092804,0.932956,2.944650,2.429373,2.950015,1.345291
2010-05,1.243454,1.117557,0.256312,-0.897047,1.169821,1.344380,1.642571,1.916994,1.696690,1.882029
...,...,...,...,...,...,...,...,...,...,...
2016-08,0.739029,-0.451752,1.483126,0.188067,-0.745817,-0.072413,0.438996,0.715950,0.394573,-0.578610
2016-09,0.739029,-0.395784,1.122298,0.247170,-0.813054,1.441119,-0.273999,0.502828,0.545799,0.043630
2016-10,0.413168,-0.791660,1.555291,0.836004,-0.732369,-1.544243,-0.280197,0.362225,0.018322,0.144558
2016-11,0.413168,-1.131016,2.565608,0.894927,-0.627236,1.417577,-1.012744,1.103489,0.778714,-0.749637


In [13]:
data

Unnamed: 0_level_0,target,ECN.pseiclose_YoY,ECN.tbill_usd90,ECN.phpusd_YoY,ECN.tdr_php360,ECN.govtexpt_YoY,ECN.m1_YoY,ECN.IPI_YoY,ECN.imports_YoY,ECN.exports_YoY
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
2010-01,1.357160,2.546622,-0.321012,-0.573727,1.246838,0.367732,0.475311,2.979768,1.602402,2.201924
2010-02,1.357160,2.590565,0.039815,-0.610389,1.325077,-0.503924,1.288223,2.606936,1.288057,2.202541
2010-03,1.357160,2.395650,0.256312,-1.218834,1.340969,1.809788,2.249231,1.844254,2.177334,2.291599
2010-04,1.243454,2.235894,0.256312,-1.601437,1.092804,0.932956,2.944650,2.429373,2.950015,1.345291
2010-05,1.243454,1.117557,0.256312,-0.897047,1.169821,1.344380,1.642571,1.916994,1.696690,1.882029
...,...,...,...,...,...,...,...,...,...,...
2016-08,0.739029,-0.451752,1.483126,0.188067,-0.745817,-0.072413,0.438996,0.715950,0.394573,-0.578610
2016-09,0.739029,-0.395784,1.122298,0.247170,-0.813054,1.441119,-0.273999,0.502828,0.545799,0.043630
2016-10,0.413168,-0.791660,1.555291,0.836004,-0.732369,-1.544243,-0.280197,0.362225,0.018322,0.144558
2016-11,0.413168,-1.131016,2.565608,0.894927,-0.627236,1.417577,-1.012744,1.103489,0.778714,-0.749637


In [4]:
def sliding_windows(data_x, seq_length, freq_ratio = 3):
    '''
    seq_length = number of rows of historical data + number of rows of input data (i.e. 12 + 3)
    '''
    x_encoder_in = []
    y_decoder_in = []
    y_target = []

    for i in range(0,len(data_x)-seq_length+1, freq_ratio):
        _x_in = data.iloc[i:(i+seq_length),1:]
        _y_in = data_x.iloc[i+freq_ratio-1:i+seq_length-freq_ratio:freq_ratio, :1] # gets every rth row but stops before the current low-freq vintage
        _y_out = data_x.iloc[i+seq_length-1, :1]
        x_encoder_in.append(_x_in)
        y_decoder_in.append(_y_in)
        y_target.append(_y_out)
    
    return np.array(x_encoder_in),np.array(y_decoder_in), np.array(y_target)
x_encoder_in, y_decoder_in, y_target = sliding_windows(data, seq_length=15)
print(x_encoder_in.shape) # (timesteps, seq_length, dim_x)
print(y_decoder_in.shape) # (timesteps, dim_y)
print(y_target.shape)
trainX_in = Variable(torch.Tensor(x_encoder_in))
trainY_in = Variable(torch.Tensor(y_decoder_in))
trainY_out = Variable(torch.Tensor(y_target))

(24, 15, 9)
(24, 4, 1)
(24, 1)


# Model Definition

In [5]:
### seq2seq y
class PreAttnEncoder(nn.Module):
    def __init__(self, dim_x, n_a, dropout_rate=0.2, bidirectional_encoder=False):
        super(PreAttnEncoder, self).__init__()
        self.bidirectional_encoder = bidirectional_encoder
        self.lstm = nn.LSTM(dim_x, n_a, batch_first=True, bidirectional=bidirectional_encoder)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        a, _ = self.lstm(x)
        a = self.dropout(a)
        return a

class OneStepAttn(nn.Module):
    def __init__(self, n_a, n_s, n_align):
        super(OneStepAttn, self).__init__()
        self.densor1 = nn.Linear(n_a + n_s, n_align)
        self.densor2 = nn.Linear(n_align, 1)

    def forward(self, a, s_prev):
        s_prev = s_prev.unsqueeze(1).repeat(1, a.size(1), 1) # (batch_size, Lx, n_s)
        concat = torch.cat((a, s_prev), dim=-1) # (batch_size, Lx, n_a + n_s)
        e = torch.tanh(self.densor1(concat))
        energies = F.relu(self.densor2(e))  # (batch_size, Lx, 1)
        alphas = F.softmax(energies, dim=1) # (batch_size, Lx, 1)
        context = torch.bmm(alphas.transpose(1, 2), a).squeeze(1) # (batch_size, n_a)
        return context

class MTMFSeq2Seq(nn.Module):
    def __init__(self, dim_x, dim_y, Lx, Ty, n_a, n_s, n_align_y, fc_y, dropout_rate, freq_ratio=3, bidirectional_encoder=False, l1reg=1e-5, l2reg=1e-4):
        super(MTMFSeq2Seq, self).__init__()
        self.dim_x = dim_x
        self.dim_y = dim_y
        self.Lx = Lx
        self.Ty = Ty
        self.n_a = n_a
        self.n_s = n_s
        self.n_align_y = n_align_y
        self.fc_y = fc_y
        self.freq_ratio = freq_ratio
        self.bidirectional_encoder = bidirectional_encoder
        self.l1reg = l1reg
        self.l2reg = l2reg

        self.pre_attn = PreAttnEncoder(dim_x, n_a, dropout_rate, bidirectional_encoder)
        self.one_step_attention_y = OneStepAttn(n_a, n_s, n_align_y)

        self.post_attn_y = nn.LSTMCell(n_a + dim_y, n_s)
        self.ffn1_y = nn.Linear(n_s, fc_y)
        self.dropout_fn_y = nn.Dropout(dropout_rate)
        self.ffn2_y = nn.Linear(fc_y, dim_y) ### add regularizer

    def initialize_state(self, batch_size, dim, device):
        return torch.zeros(batch_size, dim, device=device)

    def forward(self, x_encoder_in, y_decoder_in):
        batch_size = x_encoder_in.size(0)

        a = self.pre_attn(x_encoder_in)

        s_y, c_y = self.initialize_state(batch_size, self.n_s, device), self.initialize_state(batch_size, self.n_s, device)
        for t in range(self.Ty):
            a_idx = int((t + 1) * self.freq_ratio - 1)
            a_to_attend = a[:, (a_idx - self.freq_ratio + 1):(a_idx + 1), :]
            context = self.one_step_attention_y(a_to_attend, s_y)
            post_attn_input = torch.cat((context, y_decoder_in[:, t, :]), dim=-1)
            s_y, c_y = self.post_attn_y(post_attn_input, (s_y, c_y))

        y_pred = self.ffn1_y(s_y)
        y_pred = self.dropout_fn_y(y_pred)
        y_pred = self.ffn2_y(y_pred)

        return y_pred

In [None]:
class MTMFSeq2SeqLightning(pl.LightningModule):
    def __init__(self, model, learning_rate=0.001):
        super(MTMFSeq2SeqLightning, self).__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.criterion = torch.nn.MSELoss()

    def forward(self, x_encoder_in, y_decoder_in):
        return self.model(x_encoder_in, y_decoder_in)

    def training_step(self, batch, batch_idx):
        x_encoder_in, y_decoder_in, y_target = batch
        y_pred = self(x_encoder_in, y_decoder_in)
        loss = self.criterion(y_pred, y_target)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

# Prepare the dataset and dataloader
train_dataset = TensorDataset(trainX_in, trainY_in, trainY_out)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=8)

# Initialize the model
dim_x = trainX_in.shape[-1] # 9
dim_y = trainY_in.shape[-1] # 1
Lx = trainX_in.shape[1] # 15
Ty = trainY_in.shape[1] # 4
n_a = 128
n_s = 256
n_align_y = 16
fc_y = 128
dropout_rate = 0.4
freq_ratio = 3
bidirectional_encoder = False

model = MTMFSeq2Seq(
    dim_x=dim_x,
    dim_y=dim_y,
    Lx=Lx,
    Ty=Ty,
    n_a=n_a,
    n_s=n_s,
    n_align_y=n_align_y,
    fc_y=fc_y,
    dropout_rate=dropout_rate,
    freq_ratio=freq_ratio,
    bidirectional_encoder=bidirectional_encoder,
)

# Initialize the Lightning module
lightning_model = MTMFSeq2SeqLightning(model)

# Train the model
trainer = pl.Trainer(max_epochs=100, accelerator='auto',log_every_n_steps=2)
trainer.fit(lightning_model, train_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



  | Name      | Type        | Params | Mode 
--------------------------------------------------
0 | model     | MTMFSeq2Seq | 506 K  | train
1 | criterion | MSELoss     | 0      | train
--------------------------------------------------
506 K     Trainable params
0         Non-trainable params
506 K     Total params
2.027     Total estimated model params size (MB)
12        Modules in train mode
0         Modules in eval mode


Epoch 99: 100%|██████████| 12/12 [00:00<00:00, 24.64it/s, v_num=30, train_loss=0.0272]  

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 12/12 [00:00<00:00, 22.70it/s, v_num=30, train_loss=0.0272]


In [7]:
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/

In [None]:
# ### seq2one y
# class PreAttnEncoder(nn.Module):
#     """Pre-attention Encoder module"""
#     def __init__(self, dim_x, n_a, dropout_rate=0.2, bidirectional_encoder=False):
#         super(PreAttnEncoder, self).__init__()
#         self.bidirectional_encoder = bidirectional_encoder
#         self.lstm = nn.LSTM( input_size=dim_x, hidden_size=n_a, batch_first=True, bidirectional=bidirectional_encoder)
#         self.dropout = nn.Dropout(dropout_rate)

#     def forward(self, x):
#         a, _ = self.lstm(x)
#         a = self.dropout(a)
#         return a


# class OneStepAttn(nn.Module):
#     """Attention alignment module"""
#     def __init__(self, n_a, n_s, n_align):
#         super(OneStepAttn, self).__init__()
#         self.densor1 = nn.Linear(n_a + n_s, n_align)
#         self.densor2 = nn.Linear(n_align, 1)

#     def forward(self, a, s_prev):
#         s_prev = s_prev.unsqueeze(1).repeat(1, a.size(1), 1)
#         concat = torch.cat((a, s_prev), dim=-1)
#         e = torch.tanh(self.densor1(concat))
#         energies = F.relu(self.densor2(e))
#         alphas = F.softmax(energies, dim=1)
#         context = torch.sum(alphas * a, dim=1)
#         return context


# class MTMFSeq2One(nn.Module):
#     def __init__(
#         self,
#         Lx,
#         dim_x,
#         Ty,
#         dim_y,
#         n_a,
#         n_s,
#         n_align,
#         fc_y,
#         dropout_rate,
#         freq_ratio=3,
#         bidirectional_encoder=False
#     ):
#         super(MTMFSeq2One, self).__init__()
#         self.Lx = Lx
#         self.Ty = Ty
#         self.dim_x = dim_x
#         self.dim_y = dim_y
#         self.n_a = n_a
#         self.n_s = n_s
#         self.n_align = n_align
#         self.fc_y = fc_y
#         self.freq_ratio = freq_ratio
#         self.bidirectional_encoder = bidirectional_encoder

#         # Encoder
#         self.pre_attn = PreAttnEncoder(dim_x, n_a, dropout_rate, bidirectional_encoder)

#         # Attention alignment model
#         self.one_step_attention = OneStepAttn(n_a, n_s, n_align)

#         # Decoder
#         self.post_attn = nn.LSTMCell(input_size=n_a + dim_y, hidden_size=n_s)
#         self.ffn1 = nn.Linear(n_s, fc_y)
#         self.dropout = nn.Dropout(dropout_rate)
#         self.ffn2 = nn.Linear(fc_y, dim_y) ### add regularizer

#     def initialize_state(self, batch_size, dim):
#         return torch.zeros(batch_size, dim)

#     def forward(self, batch_inputs):
#         x, y = batch_inputs
#         batch_size = x.size(0)

#         # Stage 1: Pre-attention encoding
#         a = self.pre_attn(x)

#         # Stage 2: Attention-based decoding
#         s = self.initialize_state(batch_size, self.n_s).to(x.device)
#         c = self.initialize_state(batch_size, self.n_s).to(x.device)

#         for t in range(self.Ty):
#             a_idx = int((t + 1) * self.freq_ratio - 1)
#             a_to_attend = a[:, (a_idx - self.freq_ratio + 1):(a_idx + 1), :]
#             context = self.one_step_attention(a_to_attend, s)

#             post_attn_input = torch.cat((context, y[:, t, :].unsqueeze(1)), dim=-1)
#             s, c = self.post_attn(post_attn_input.squeeze(1), (s, c))

#         y_pred = F.relu(self.ffn1(s))
#         y_pred = self.dropout(y_pred)
#         y_pred = self.ffn2(y_pred)

#         return y_pred

In [None]:
# ### seq2one x, y

# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class PreAttnEncoder(nn.Module):
#     """Pre-attention Encoder module"""
#     def __init__(self, dim_x, fc_dim, n_a, dropout_rate=0.2, bidirectional_encoder=False, l1reg=1e-5, l2reg=1e-4):
#         super(PreAttnEncoder, self).__init__()
#         self.bidirectional_encoder = bidirectional_encoder
#         self.l1reg = l1reg
#         self.l2reg = l2reg

#         self.lstm = nn.LSTM(input_size=dim_x, hidden_size=n_a, batch_first=True, bidirectional=bidirectional_encoder)
#         self.ffn1 = nn.Linear(n_a * (2 if bidirectional_encoder else 1), fc_dim)
#         self.dropout = nn.Dropout(dropout_rate)
#         self.ffn2 = nn.Linear(fc_dim, dim_x) ### add regularizer

#     def forward(self, x):
#         """
#         Forward pass for the encoder
#         Args:
#             x: (tensor) shape (batch_size, Lx, dim_x)
#         Returns:
#             x_pred: (tensor), the next step prediction for x (batch_size, dim_x)
#             a: (tensor) sequence of LSTM hidden states, (batch_size, Lx, n_a)
#         """
#         a, _ = self.lstm(x)
#         x_pred = F.relu(self.ffn1(a[:, -1, :]))
#         x_pred = self.dropout(x_pred)
#         x_pred = self.ffn2(x_pred)
#         return x_pred, a


# class OneStepAttn(nn.Module):
#     """Attention alignment module"""
#     def __init__(self, n_a, n_s, n_align):
#         super(OneStepAttn, self).__init__()
#         self.densor1 = nn.Linear(n_a + n_s, n_align)
#         self.densor2 = nn.Linear(n_align, 1)

#     def forward(self, a, s_prev):
#         """
#         Performs one step of attention
#         Args:
#             a: hidden state from the pre-attention LSTM, shape = (batch_size, Lx, n_a)
#             s_prev: previous hidden state of the post-attention LSTM, shape = (batch_size, n_s)
#         Returns:
#             context: context vector, input of the next post-attention LSTM cell
#         """
#         s_prev = s_prev.unsqueeze(1).repeat(1, a.size(1), 1)  # (batch_size, Lx, n_s)
#         concat = torch.cat((a, s_prev), dim=-1)  # (batch_size, Lx, n_a + n_s)
#         e = torch.tanh(self.densor1(concat))
#         energies = F.relu(self.densor2(e))  # (batch_size, Lx, 1)
#         alphas = F.softmax(energies, dim=1)  # (batch_size, Lx, 1)
#         context = torch.bmm(alphas.transpose(1,2), a).squeeze(1)  # (batch_size, n_a)
#         return context


# class MTMFSeq2One(nn.Module):
#     def __init__( self, Lx, dim_x, Ty, dim_y, n_a, n_s, n_align, fc_x, fc_y, dropout_rate, freq_ratio=3, bidirectional_encoder=False, l1reg=1e-5, l2reg=1e-4):
#         super(MTMFSeq2One, self).__init__()
#         self.dim_x = dim_x
#         self.dim_y = dim_y
#         self.Lx = Lx
#         self.Ty = Ty
#         self.n_a = n_a
#         self.n_s = n_s
#         self.n_align = n_align
#         self.fc_x = fc_x
#         self.fc_y = fc_y
#         self.freq_ratio = freq_ratio
#         self.bidirectional_encoder = bidirectional_encoder
#         self.l1reg = l1reg
#         self.l2reg = l2reg

#         # Encoder
#         self.pre_attn = PreAttnEncoder(dim_x, fc_x, n_a, dropout_rate, bidirectional_encoder, l1reg, l2reg)
#         # Attention alignment model
#         self.one_step_attention = OneStepAttn(n_a, n_s, n_align)
#         # Decoder
#         self.post_attn = nn.LSTMCell(n_a + dim_y, n_s)
#         self.ffn1 = nn.Linear(n_s, fc_y)
#         self.dropout = nn.Dropout(dropout_rate)
#         self.ffn2 = nn.Linear(fc_y, dim_y) ### add regularizer

#     def initialize_state(self, batch_size, dim):
#         return torch.zeros(batch_size, dim)

#     def forward(self, batch_inputs):
#         """
#         Forward pass
#         Args:
#             batch_inputs: tuple of (x, y) where
#                 x: encoder input, shape (batch_size, Lx, dim_x)
#                 y: decoder input, shape (batch_size, Ty, dim_y)
#         Returns:
#             x_pred: prediction for x
#             y_pred: prediction for y
#         """
#         x, y = batch_inputs
#         batch_size = x.size(0)

#         # Stage 1: Pre-attention encoding
#         x_pred, a = self.pre_attn(x)

#         # Stage 2: Attention-based decoding
#         s = self.initialize_state(batch_size, self.n_s).to(x.device)
#         c = self.initialize_state(batch_size, self.n_s).to(x.device)
#         for t in range(self.Ty):
#             a_idx = int((t + 1) * self.freq_ratio - 1)
#             a_to_attend = a[:, (a_idx - self.freq_ratio + 1):(a_idx + 1), :]
#             context = self.one_step_attention(a_to_attend, s)
#             post_attn_input = torch.cat((context, y[:, t, :]), dim=-1)
#             s, c = self.post_attn(post_attn_input, (s, c))

#         y_pred = F.relu(self.ffn1(s))
#         y_pred = self.dropout(y_pred)
#         y_pred = self.ffn2(y_pred)
#         return x_pred, y_pred