In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence
import spacy
import logging
import os
from typing import Tuple, List, Dict
import torch.nn.functional as F
from tqdm import tqdm

  return torch._C._cuda_getDeviceCount() > 0


In [2]:
# Configuration class for hyperparameters and settings
class Config:
    """Configuration for data processing and Seq2Seq model."""
    def __init__(self):
        self.batch_size = 32
        self.max_seq_len = 50
        self.embedding_dim = 128
        self.hidden_dim = 256
        self.num_layers = 2
        self.dropout = 0.1
        self.attention_dim = 128
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_save_path = 'seq2seq_model.pth'
        self.log_file = 'data_processing.log'

In [3]:

# Setup logging
logging.basicConfig(filename=Config().log_file, level=logging.INFO, 
                    format='%(asctime)s - %(levelname)s - %(message)s')


In [4]:
# Load SpaCy for tokenization
try:
    spacy_en = spacy.load('en_core_web_sm')
except OSError as e:
    logging.error(f"Failed to load SpaCy model: {e}")
    raise

def tokenize_text(text: str) -> List[str]:
    """Tokenize text using SpaCy."""
    return [tok.text.lower() for tok in spacy_en.tokenizer(text)]

In [5]:
# Dataset loading and preprocessing
def load_and_preprocess_data(dataset_name: str = "pubmed_qa", split: str = "train", max_seq_len: int = 50) -> Tuple[DataLoader, Dict[str, int], Dict[str, int]]:
    """Load and preprocess a Hugging Face dataset for Seq2Seq training.

    Args:
        dataset_name (str): Name of the Hugging Face dataset (e.g., 'pubmed_qa').
        split (str): Dataset split ('train', 'validation', 'test').
        max_seq_len (int): Maximum sequence length for source and target.

    Returns:
        Tuple[DataLoader, Dict[str, int], Dict[str, int]]: DataLoader, source vocabulary, target vocabulary.
    """
    try:
        # Load dataset from Hugging Face
        dataset = load_dataset(dataset_name, 'pqa_labeled', split=split)
        logging.info(f"Loaded dataset {dataset_name} ({split} split) with {len(dataset)} samples")

        # Build vocabularies
        def yield_tokens(data, key):
            for sample in data:
                yield tokenize_text(sample[key])

        src_vocab = build_vocab_from_iterator(yield_tokens(dataset, 'question'), 
                                             specials=['<unk>', '<pad>', '<sos>', '<eos>'], min_freq=2)
        tgt_vocab = build_vocab_from_iterator(yield_tokens(dataset, 'answer'), 
                                             specials=['<unk>', '<pad>', '<sos>', '<eos>'], min_freq=2)
        src_vocab.set_default_index(src_vocab['<unk>'])
        tgt_vocab.set_default_index(tgt_vocab['<unk>'])
        logging.info(f"Source vocab size: {len(src_vocab)}, Target vocab size: {len(tgt_vocab)}")

        # Collate function for DataLoader
        def collate_fn(batch):
            src_batch, tgt_batch = [], []
            for sample in batch:
                src_tokens = ['<sos>'] + tokenize_text(sample['question'])[:max_seq_len - 2] + ['<eos>']
                tgt_tokens = ['<sos>'] + tokenize_text(sample['answer'])[:max_seq_len - 2] + ['<eos>']
                src_batch.append(torch.tensor([src_vocab[token] for token in src_tokens]))
                tgt_batch.append(torch.tensor([tgt_vocab[token] for token in tgt_tokens]))
            src_batch = pad_sequence(src_batch, padding_value=src_vocab['<pad>'], batch_first=True)
            tgt_batch = pad_sequence(tgt_batch, padding_value=tgt_vocab['<pad>'], batch_first=True)
            return src_batch, tgt_batch

        dataloader = DataLoader(dataset, batch_size=Config().batch_size, collate_fn=collate_fn)
        return dataloader, src_vocab, tgt_vocab

    except Exception as e:
        logging.error(f"Failed to load or preprocess dataset: {e}")
        raise

In [6]:
# Placeholder for image preprocessing (for multimodal extension)
def preprocess_image_data(image_paths: List[str], processor=None) -> torch.Tensor:
    """Placeholder for preprocessing medical images (e.g., CheXpert).

    Args:
        image_paths (List[str]): List of image file paths.
        processor: Image processor (e.g., from Hugging Face Transformers).

    Returns:
        torch.Tensor: Processed image tensors.
    """
    try:
        # Example: Use a processor like CLIPProcessor for MedGemma-like models
        if processor is None:
            raise ValueError("Image processor not provided for multimodal data")
        # Implement image loading and processing here (e.g., using PIL or torchvision)
        # For now, return a dummy tensor
        logging.info("Image preprocessing placeholder executed")
        return torch.zeros(len(image_paths), 3, 224, 224)  # Dummy tensor (batch, channels, height, width)
    except Exception as e:
        logging.error(f"Image preprocessing failed: {e}")
        raise

In [7]:
!nvidia-smi

Tue Jul 22 12:23:32 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.247.01             Driver Version: 535.247.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        Off | 00000000:01:00.0 Off |                  Off |
|  0%   31C    P8              21W / 480W |    157MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [8]:
import torch

print("CUDA Available:", torch.cuda.is_available())
print("Device Count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("Device Name:", torch.cuda.get_device_name(0))
else:
    print("CUDA is NOT available. Check driver or GPU.")


CUDA Available: False
Device Count: 1
CUDA is NOT available. Check driver or GPU.


In [2]:
import torch
import torch.nn as nn
from typing import Tuple

class BiLSTMEncoder(nn.Module): # nn is a class in pytorch that provides base class for all neural network. 
    """Bidirectional LSTM encoder for sequence-to-sequence models.

    Encodes input sequences into hidden states and outputs, suitable for tasks like machine translation.
    Uses a bidirectional LSTM to capture context from both directions, followed by a linear layer for output projection.

    Args:
        vocab_size (int): Size of the input vocabulary.
        embedding_dim (int): Dimension of token embeddings.
        hidden_dim (int): Dimension of LSTM hidden states per direction.
        num_layers (int): Number of LSTM layers.
        dropout (float): Dropout probability for regularization.
        output_dim (int): Dimension of the output (e.g., target vocabulary size for classification).

    Attributes:
        embedding (nn.Embedding): Token embedding layer.
        lstm (nn.LSTM): Bidirectional LSTM layer.
        dropout (nn.Dropout): Dropout layer.
        fc (nn.Linear): Linear layer to project concatenated hidden states.
    """
    def __init__(self, vocab_size: int, embedding_dim: int, hidden_dim: int, #constsructor to initialize layrs and args . 
                 num_layers: int, dropout: float, output_dim: int):
        super(BiLSTMEncoder, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.output_dim = output_dim
        self.dropout_rate = dropout

        # Initialize layers
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True, #input shape (batch_size,sequence_length, embedding_dim) by default shape of lstm is sequence_length first.)
            bidirectional=True, 
            dropout=dropout if num_layers > 1 else 0.0  # Dropout only if multiple layers
        )
        #Pytorch is case sensetive. All module names uses upper case letter.
        #like nn.LSTM,nn.RNN
        #input dimension= embedding_dim
        #both ht and ct are of same size(num_layers*num_directions,batch_size,hidden_dim) 
        #number of direction is 2 for bidirectional lstm 
  
        #number of layers=num_layers is the number of stacked LSTM layers (vertical depth).
        #make sure layers are capitalized while calling, we will not use small letter like nn.LSTM or nn.Linear not nn.linear 
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)  #as we are uing bidirectional LSTM, hidden_dim *2 

        # Initialize weights for stability  
        nn.init.xavier_uniform_(self.embedding.weight)  #xavier initialization for embedding.
        for name, param in self.lstm.named_parameters(): #xavier initialization for LSTM weights
            #name parameters means all the parameters of LSTM gates.
            if 'weight' in name:
                nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)
        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Performs the forward pass of the encoder(real calculation happens).

        Args:
            x (torch.Tensor): Input token IDs, shape (batch_size, sequence_length).

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - output: LSTM outputs, shape (batch_size, sequence_length, hidden_dim * 2).
                - hidden: Final hidden states, shape (num_layers * 2, batch_size, hidden_dim).
                - cell: Final cell states, shape (num_layers * 2, batch_size, hidden_dim).

        Raises:
            ValueError: If input tensor shape or dimensions are invalid.
        """
        # Validate input
        if x.dim() != 2:
            raise ValueError(f"Expected input shape (batch_size, sequence_length), got {x.shape}")
        if not torch.all(x >= 0) or not torch.all(x < self.vocab_size):
            raise ValueError(f"Input token IDs must be in [0, {self.vocab_size}), got min {x.min()}, max {x.max()}")

        # Embed input: (batch_size, sequence_length) -> (batch_size, sequence_length, embedding_dim)
        embedded = self.embedding(x)

        # Apply LSTM: (batch_size, sequence_length, embedding_dim) -> 
        # (batch_size, sequence_length, hidden_dim * 2), (num_layers * 2, batch_size, hidden_dim)
        output, (hidden, cell) = self.lstm(embedded)

        # Concatenate final forward and backward hidden states: (batch_size, hidden_dim * 2)
        final_hidden = torch.cat((hidden[-2], hidden[-1]), dim=1)

        # Apply dropout and linear layer: (batch_size, hidden_dim * 2) -> (batch_size, output_dim)
        dropout = self.dropout(final_hidden)
        out = self.fc(dropout)

        return out, hidden, cell

In [10]:


class AdditiveAttention(nn.Module):
    """Implements Bahdanau-style additive attention for sequence-to-sequence models.

    Computes attention scores between the decoder's hidden state and encoder outputs,
    producing a context vector and attention weights for use in decoding.

    Args:
        encoder_hidden_dim (int): Dimension of the encoder's hidden states.
        decoder_hidden_dim (int): Dimension of the decoder's hidden states.
        attention_dim (int): Dimension of the attention mechanism's hidden layer.

    Attributes:
        encoder_attn (nn.Linear): Linear layer to project encoder outputs.
        decoder_attn (nn.Linear): Linear layer to project decoder hidden state.
        v (nn.Parameter): Parameter vector to compute attention scores.
    """
    def __init__(self, encoder_hidden_dim: int, decoder_hidden_dim: int, attention_dim: int):
        super(AdditiveAttention, self).__init__()
        self.encoder_hidden_dim = encoder_hidden_dim
        self.decoder_hidden_dim = decoder_hidden_dim
        self.attention_dim = attention_dim

        # Linear layers to project encoder and decoder states
        self.encoder_attn = nn.Linear(encoder_hidden_dim, attention_dim, bias=False)
        self.decoder_attn = nn.Linear(decoder_hidden_dim, attention_dim, bias=False)
        
        # Attention score parameter, initialized with Glorot initialization for stability
        self.v = nn.Parameter(torch.empty(attention_dim))
        nn.init.xavier_uniform_(self.v.unsqueeze(0))  # Shape: (1, attention_dim)

    def forward(self, encoder_outputs: torch.Tensor, decoder_hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes the attention context vector and weights.

        Args:
            encoder_outputs (torch.Tensor): Encoder hidden states, shape (batch_size, src_len, encoder_hidden_dim).
            decoder_hidden (torch.Tensor): Decoder hidden state, shape (batch_size, decoder_hidden_dim).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - context: Context vector, shape (batch_size, encoder_hidden_dim).
                - attn_weights: Attention weights, shape (batch_size, src_len).

        Raises:
            ValueError: If input tensor shapes or dimensions do not match expected values.
        """
        # Validate input shapes
        if encoder_outputs.dim() != 3:
            raise ValueError(
                f"Expected encoder_outputs to be 3D (batch_size, src_len, encoder_hidden_dim), got {encoder_outputs.shape}"
            )
        if decoder_hidden.dim() != 2:
            raise ValueError(
                f"Expected decoder_hidden to be 2D (batch_size, decoder_hidden_dim), got {decoder_hidden.shape}"
            )

        batch_size, src_len, enc_dim = encoder_outputs.size()
        if enc_dim != self.encoder_hidden_dim:
            raise ValueError(f"Encoder hidden dimension mismatch: expected {self.encoder_hidden_dim}, got {enc_dim}")
        if decoder_hidden.size(1) != self.decoder_hidden_dim:
            raise ValueError(f"Decoder hidden dimension mismatch: expected {self.decoder_hidden_dim}, got {decoder_hidden.size(1)}")
        if batch_size != decoder_hidden.size(0):
            raise ValueError(f"Batch size mismatch: encoder_outputs {batch_size}, decoder_hidden {decoder_hidden.size(0)}")

        # Repeat decoder hidden state to match source length: (batch_size, decoder_hidden_dim) -> (batch_size, src_len, decoder_hidden_dim)
        decoder_hidden = decoder_hidden.unsqueeze(1).repeat(1, src_len, 1)

        # Compute energy: (batch_size, src_len, encoder_hidden_dim) -> (batch_size, src_len, attention_dim)
        #                + (batch_size, src_len, decoder_hidden_dim) -> (batch_size, src_len, attention_dim)
        energy = torch.tanh(self.encoder_attn(encoder_outputs) + self.decoder_attn(decoder_hidden))

        # Compute attention scores: (batch_size, src_len, attention_dim) @ (attention_dim,) -> (batch_size, src_len)
        attention_scores = torch.matmul(energy, self.v)

        # Apply softmax to get attention weights: (batch_size, src_len)
        attn_weights = F.softmax(attention_scores, dim=1)

        # Compute context vector: (batch_size, 1, src_len) @ (batch_size, src_len, encoder_hidden_dim) -> (batch_size, 1, encoder_hidden_dim)
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)

        return context, attn_weights

In [11]:
class DecoderWithAttention(nn.Module):
    """Decoder with additive attention for sequence-to-sequence models.

    Processes one token at a time, using attention to focus on relevant encoder outputs.
    Outputs predictions for the next token and updated LSTM states.

    Args:
        output_dim (int): Size of the target vocabulary.
        embed_dim (int): Dimension of token embeddings.
        encoder_hidden_dim (int): Dimension of encoder hidden states.
        decoder_hidden_dim (int): Dimension of decoder hidden states.
        attention_dim (int): Dimension of the attention mechanism's hidden layer.
        dropout (float, optional): Dropout probability. Defaults to 0.1.

    Attributes:
        embedding (nn.Embedding): Token embedding layer.
        attention (AdditiveAttention): Attention mechanism.
        rnn (nn.LSTM): LSTM layer for decoding.
        fc_out (nn.Linear): Output projection layer.
        dropout (nn.Dropout): Dropout layer.
    """
    def __init__(self, output_dim: int, embed_dim: int, encoder_hidden_dim: int, 
                 decoder_hidden_dim: int, attention_dim: int, dropout: float = 0.1):
        super(DecoderWithAttention, self).__init__()
        self.output_dim = output_dim
        self.embed_dim = embed_dim
        self.encoder_hidden_dim = encoder_hidden_dim
        self.decoder_hidden_dim = decoder_hidden_dim
        self.attention_dim = attention_dim

        # Initialize layers
        self.embedding = nn.Embedding(output_dim, embed_dim)
        self.attention = AdditiveAttention(encoder_hidden_dim, decoder_hidden_dim, attention_dim)
        self.rnn = nn.LSTM(embed_dim + encoder_hidden_dim, decoder_hidden_dim, batch_first=True)
        self.fc_out = nn.Linear(decoder_hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_token: torch.Tensor, decoder_hidden: torch.Tensor, 
                decoder_cell: torch.Tensor, encoder_outputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Performs one decoding step.

        Args:
            input_token (torch.Tensor): Input token IDs, shape (batch_size, 1).
            decoder_hidden (torch.Tensor): Previous hidden state, shape (1, batch_size, decoder_hidden_dim).
            decoder_cell (torch.Tensor): Previous cell state, shape (1, batch_size, decoder_hidden_dim).
            encoder_outputs (torch.Tensor): Encoder outputs, shape (batch_size, src_len, encoder_hidden_dim).

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                - prediction: Logits for next token, shape (batch_size, output_dim).
                - hidden: Updated hidden state, shape (1, batch_size, decoder_hidden_dim).
                - cell: Updated cell state, shape (1, batch_size, decoder_hidden_dim).
                - attn_weights: Attention weights, shape (batch_size, src_len).

        Raises:
            ValueError: If input shapes do not match expected dimensions.
        """
        # Validate input shapes
        if input_token.dim() != 2 or input_token.size(1) != 1:
            raise ValueError(f"Expected input_token shape (batch_size, 1), got {input_token.shape}")
        if decoder_hidden.dim() != 3 or decoder_hidden.size(0) != 1:
            raise ValueError(f"Expected decoder_hidden shape (1, batch_size, decoder_hidden_dim), got {decoder_hidden.shape}")
        if decoder_cell.shape != decoder_hidden.shape:
            raise ValueError(f"Expected decoder_cell shape to match decoder_hidden, got {decoder_cell.shape}")
        if encoder_outputs.dim() != 3:
            raise ValueError(f"Expected encoder_outputs shape (batch_size, src_len, encoder_hidden_dim), got {encoder_outputs.shape}")

        batch_size = input_token.size(0)

        # Embed input token: (batch_size, 1) -> (batch_size, 1, embed_dim)
        embedded = self.dropout(self.embedding(input_token))

        # Compute attention: (batch_size, src_len, encoder_hidden_dim), (batch_size, decoder_hidden_dim)
        # -> (batch_size, encoder_hidden_dim), (batch_size, src_len)
        context, attn_weights = self.attention(encoder_outputs, decoder_hidden.squeeze(0))
        context = context.unsqueeze(1)  # (batch_size, 1, encoder_hidden_dim)

        # Concatenate embedding and context: (batch_size, 1, embed_dim + encoder_hidden_dim)
        rnn_input = torch.cat((embedded, context), dim=2)

        # LSTM: (batch_size, 1, embed_dim + encoder_hidden_dim) -> (batch_size, 1, decoder_hidden_dim)
        output, (hidden, cell) = self.rnn(rnn_input, (decoder_hidden, decoder_cell))

        # Predict next token: (batch_size, 1, decoder_hidden_dim) -> (batch_size, output_dim)
        prediction = self.fc_out(output.squeeze(1))

        return prediction, hidden, cell, attn_weights

In [12]:
# Seq2Seq Model
class Seq2Seq(nn.Module):
    """Sequence-to-Sequence model combining encoder and decoder."""
    def __init__(self, encoder: nn.Module, decoder: nn.Module):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src: torch.Tensor, tgt: torch.Tensor, teacher_forcing_ratio: float = 0.5) -> torch.Tensor:
        batch_size, tgt_len = tgt.size()
        outputs = torch.zeros(batch_size, tgt_len, self.decoder.output_dim).to(src.device)
        encoder_outputs, hidden, cell = self.encoder(src)
        hidden = torch.cat((hidden[-2], hidden[-1]), dim=1).unsqueeze(0)
        cell = torch.cat((cell[-2], cell[-1]), dim=1).unsqueeze(0)
        input_token = tgt[:, 0:1]

        for t in range(1, tgt_len):
            output, hidden, cell, _ = self.decoder(input_token, hidden, cell, encoder_outputs)
            outputs[:, t] = output
            input_token = tgt[:, t:t+1] if torch.rand(1).item() < teacher_forcing_ratio else output.argmax(1, keepdim=True)
        return outputs

FOR epoch in range(num_epochs): 
    SET model to training mode
    INIT loss to 0
    FOR each batch (src, tgt) in dataloader:
        Move data to device
        Zero out gradients
        Forward pass through model → Get output
        Calculate loss using output and target
        Backpropagate
        Clip gradients
        Update weights using optimizer
        Track loss
    END
    Log average loss
SAVE model


In [14]:
def train_model(model: nn.Module, dataloader: DataLoader, optimizer: optim.Optimizer, 
                criterion: nn.Module, num_epochs: int, device: torch.device, save_path: str):
    """Train the Seq2Seq model with progress tracking and logging."""
    try:
        model.to(device)  # Move model to GPU/CPU
        for epoch in range(num_epochs):
           model.train()  # Set to training mode (enables dropout)
           epoch_loss = 0
           with tqdm(total=len(dataloader), desc=f'Epoch {epoch + 1}/{num_epochs}', unit='batch') as pbar:
                for src, tgt in dataloader:  # Iterate over batches
                    src, tgt = src.to(device), tgt.to(device)  # Move data to device
                    optimizer.zero_grad()  # Clear previous gradients
                    output = model(src, tgt, Config().teacher_forcing_ratio)  # Forward pass
                    loss = criterion(output[:, 1:].reshape(-1, model.decoder.output_dim), 
                                   tgt[:, 1:].reshape(-1))  # Compute loss
                    loss.backward()  # Backpropagate
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Clip gradients
                    optimizer.step()  # Update weights
                    epoch_loss += loss.item()  # Track loss
                    pbar.update(1)  # Update progress bar
                    pbar.set_postfix(loss=loss.item())  # Show current loss
            avg_loss = epoch_loss / len(dataloader)  # Average loss per batch
            logging.info(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}')  # Log result
            print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}')  # Print result
        torch.save(model.state_dict(), save_path)  # Save model weights
        logging.info(f'Model saved to {save_path}')  # Log save action
    except Exception as e:
        logging.error(f"Training failed: {e}")  # Log errors
        raise  # Re-raise for debugging

IndentationError: unindent does not match any outer indentation level (<tokenize>, line 22)

In [None]:
# Inference function for translation
def translate_sentence(model: nn.Module, sentence: str, src_vocab, tgt_vocab, 
                      device: torch.device, max_length: int = 50) -> str:
    """Translate a sentence using the trained Seq2Seq model."""
    try:
        model.eval()
        tokens = ['<sos>'] + tokenize_en(sentence) + ['<eos>']
        src_tensor = torch.tensor([src_vocab.get(token, src_vocab['<unk>']) for token in tokens], 
                                device=device).unsqueeze(0)
        with torch.no_grad():
            encoder_outputs, hidden, cell = model.encoder(src_tensor)
            hidden = torch.cat((hidden[-2], hidden[-1]), dim=1).unsqueeze(0)
            cell = torch.cat((cell[-2], cell[-1]), dim=1).unsqueeze(0)
            input_token = torch.tensor([tgt_vocab['<sos>']], device=device).unsqueeze(0)
            outputs = []
            for _ in range(max_length):
                output, hidden, cell, _ = model.decoder(input_token, hidden, cell, encoder_outputs)
                pred_token = output.argmax(1).item()
                if pred_token == tgt_vocab['<eos>']:
                    break
                outputs.append(pred_token)
                input_token = torch.tensor([[pred_token]], device=device)
        translation = ' '.join([tgt_vocab.get_itos()[idx] for idx in outputs])
        return translation
    except Exception as e:
        logging.error(f"Inference failed: {e}")

In [None]:
# Command-line interface
def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description='Seq2Seq Model Training and Inference')
    parser.add_argument('--train', action='store_true', help='Train the model')
    parser.add_argument('--infer', type=str, help='Translate a sentence')
    parser.add_argument('--load_model', type=str, default=None, help='Path to pre-trained model')
    return parser.parse_args()

# Main execution
if __name__ == "__main__":
    args = parse_args()
    config = Config()
    
    # Load data
    train_dataloader, src_vocab, tgt_vocab = get_dataloader(split='train', 
                                                          batch_size=config.batch_size, 
                                                          max_seq_len=config.max_seq_len)
    
    # Model instantiation
    encoder = BiLSTMEncoder(
        vocab_size=len(src_vocab),
        embedding_dim=config.embedding_dim,
        hidden_dim=config.hidden_dim,
        num_layers=config.num_layers,
        dropout=config.dropout
    )
    decoder = DecoderWithAttention(
        output_dim=len(tgt_vocab),
        embed_dim=config.embedding_dim,
        encoder_hidden_dim=config.hidden_dim * 2,  # Bidirectional
        decoder_hidden_dim=config.hidden_dim * 2,  # Match encoder's output
        attention_dim=config.attention_dim,
        dropout=config.dropout
    )
    model = Seq2Seq(encoder, decoder)
    
    # Loss and optimizer
    criterion = nn.CrossEntropyLoss(ignore_index=tgt_vocab['<pad>'])
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    
    # Load pre-trained model if specified
    if args.load_model and os.path.exists(args.load_model):
        model.load_state_dict(torch.load(args.load_model))
        logging.info(f"Loaded pre-trained model from {args.load_model}")
    
    # Training
    if args.train:
        train_model(model, train_dataloader, optimizer, criterion, 
                   config.num_epochs, config.device, config.model_save_path)
    
    # Inference
    if args.infer:
        if not os.path.exists(config.model_save_path) and not args.load_model:
            raise FileNotFoundError("No trained model found. Please train the model first or specify a pre-trained model path.")
        if args.load_model:
            model.load_state_dict(torch.load(args.load_model))
        else:
            model.load_state_dict(torch.load(config.model_save_path))
        translation = translate_sentence(model, args.infer, src_vocab, tgt_vocab, config.device)
        print(f'Translation: {translation}')

NameError: name 'argparse' is not defined