In [4]:
import torch
from torch import nn
from d2l import torch as d2l

class BiLSTMEncoder(nn.Module):
    """Embed token indices and encode them with a stacked bidirectional LSTM.

    The LSTM is built with ``batch_first=True``, so inputs and outputs are
    laid out as (batch, time_steps, features).
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0):
        """Create the embedding table and the bidirectional LSTM.

        Args:
            vocab_size: Number of rows in the embedding table.
            embed_size: Width of each embedding vector (the LSTM input size).
            num_hiddens: Hidden size of the LSTM, per direction.
            num_layers: Number of stacked LSTM layers.
            dropout: Dropout value handed to ``nn.LSTM``.
        """
        super().__init__()
        # The embedding is created before the LSTM and both attribute names
        # are kept, so parameter initialisation order and state_dict keys
        # stay the same.
        self.embedding = nn.Embedding(num_embeddings=vocab_size,
                                      embedding_dim=embed_size)
        lstm_options = dict(
            input_size=embed_size,
            hidden_size=num_hiddens,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=True,
            batch_first=True,
        )
        self.rnn = nn.LSTM(**lstm_options)

    def forward(self, X):
        """Encode a batch of token-index sequences.

        Args:
            X: Token indices of shape (batch_size, num_steps).

        Returns:
            A pair ``(output, state)``: ``output`` has shape
            (batch_size, num_steps, 2 * num_hiddens), the forward and backward
            features concatenated at every time step; ``state`` is the
            ``(h_n, c_n)`` tuple of final hidden/cell states produced by the
            LSTM.
        """
        # (batch_size, num_steps) -> (batch_size, num_steps, embed_size)
        embedded = self.embedding(X)
        sequence_features, final_state = self.rnn(embedded)
        return sequence_features, final_state

# Example usage
if __name__ == "__main__":
    # Smoke test: push one random batch through the encoder and print the
    # shapes of what comes back.

    # Hyperparameters
    vocab_size = 1000
    embed_size = 100
    num_hiddens = 128
    num_layers = 2
    dropout = 0.1
    batch_size = 4
    seq_len = 10

    # Create model
    encoder = BiLSTMEncoder(vocab_size, embed_size, num_hiddens, num_layers, dropout)
    # This is an inspection-only run: switch off dropout so the pass is
    # deterministic for a given set of weights and inputs.
    encoder.eval()

    # Sample input: batch of 4 sequences, each with 10 token indices
    X = torch.randint(0, vocab_size, (batch_size, seq_len))

    # Forward pass; no gradients are needed just to look at shapes.
    with torch.no_grad():
        output, state = encoder(X)

    # Inspect shapes:
    print("Output shape:", output.shape)  # Should be [4, 10, 256] (2*128)

    # state is the (h_n, c_n) tuple; only the hidden state is inspected here.
    h_n, _ = state  # h_n shape: [num_layers*2, batch_size, num_hiddens]
    print("Final hidden state shape:", h_n.shape)

    # h_n stacks (layer, direction) pairs along dim 0, ordered layer by layer
    # with forward before backward. The last two entries therefore BOTH
    # belong to the last layer:
    last_forward = h_n[-2]   # last layer, forward direction
    last_backward = h_n[-1]  # last layer, backward direction
    combined = torch.cat((last_forward, last_backward), dim=1)  # [batch_size, 2*num_hiddens]
    print("Combined final state shape:", combined.shape)

ModuleNotFoundError: No module named 'torchvision'