# In [None]:
import torch
from torch import nn
import pandas as pd
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import copy
import torch.utils.data as tud
# from torch.utils.data import Dataset, DataLoader # Replaced by tud above for slight variation
import gc # Garbage collector interface
import random
import wandb
import csv
import matplotlib.pyplot as plt
import seaborn as sns
import numpy_alias as np # Using an alias for numpy
from matplotlib.font_manager import FontProperties

# Alias for numpy
import numpy as numpy_alias

# Log in to Weights & Biases for experiment tracking. A login failure is
# reported and recorded in `is_wandb_logged_in`; it does not stop the notebook.
#
# SECURITY: API keys must never be committed to source. The key is read from
# the WANDB_API_KEY environment variable (e.g. exported from a Kaggle secret).
# The key that used to be hard-coded here is exposed in version history and
# should be revoked.
import os

# None lets wandb.login() fall back to its own credential lookup
# (environment / netrc / prompt, per the wandb.login documentation).
KEY_FOR_WANDB = os.environ.get("WANDB_API_KEY")
LOGIN_SUCCESS_MESSAGE = "Weights & Biases login was successful."
LOGIN_FAILURE_MESSAGE = "Weights & Biases login failed: {}"
try:
    wandb.login(key=KEY_FOR_WANDB)
    print(LOGIN_SUCCESS_MESSAGE)
    is_wandb_logged_in = True
except Exception as e_login:  # broad on purpose: any login problem is reported, not fatal
    is_wandb_logged_in = False
    print(LOGIN_FAILURE_MESSAGE.format(e_login))

# Pick the compute device once at start-up. Later helpers (e.g. get_chars)
# create their tensors directly on `device`.
GPU_AVAILABLE_FLAG = torch.cuda.is_available()
device, CURRENT_DEVICE_INFO = (
    (torch.device("cuda"), "CUDA (GPU) is available and will be used.")
    if GPU_AVAILABLE_FLAG
    else (torch.device("cpu"), "CUDA (GPU) not available; defaulting to CPU.")
)
print(CURRENT_DEVICE_INFO)  # report the selection once

# Special marker characters used when framing and padding words:
# START ('<') is prepended, END ('>') is appended, and PAD ('_') fills the
# remainder up to a fixed length.
COMMENCEMENT_TOKEN_CONST, TERMINATION_TOKEN_CONST, FILLER_TOKEN_CONST = "<", ">", "_"
START_TOKEN, END_TOKEN, PAD_TOKEN = (
    COMMENCEMENT_TOKEN_CONST,
    TERMINATION_TOKEN_CONST,
    FILLER_TOKEN_CONST,
)

# Teacher-forcing ratio: controls how often the model sees ground-truth inputs
# during recurrent generation.
TEACHER_FORCING_RATIO = TEACHER_FORCING_PROBABILITY = 0.5
UNUSED_CONSTANT_VALUE = 42  # not referenced in the visible code; kept for compatibility

# Header-less CSV files (read below with header=None) for the train / test /
# validation splits of the sampled Aksharantar "tel" subset, as mounted on
# Kaggle. The *_PATH_STR constants and the short names refer to the same strings.
train_csv = TRAIN_CSV_PATH_STR = "/kaggle/input/aksh11/aksharantar_sampled/tel/tel_train.csv"
test_csv = TEST_CSV_PATH_STR = "/kaggle/input/aksh11/aksharantar_sampled/tel/tel_test.csv"
val_csv = VALID_CSV_PATH_STR = "/kaggle/input/aksh11/aksharantar_sampled/tel/tel_valid.csv"

# --- Data Loading ---
# Each split is a header-less CSV: column 0 holds the source word, column 1 the
# target word. keep_default_na=False stops pandas from treating strings such as
# "NA" or "null" as missing; only genuinely empty fields (na_values=['']) become NaN.
train_dataframe_raw = pd.read_csv(train_csv, header=None, keep_default_na=False, na_values=[''])
test_dataframe_raw = pd.read_csv(test_csv, header=None, keep_default_na=False, na_values=[''])
val_dataframe_raw = pd.read_csv(val_csv, header=None, keep_default_na=False, na_values=[''])

train_df = train_dataframe_raw
test_df = test_dataframe_raw
val_df = val_dataframe_raw

# Small preview kept for interactive inspection.
if not train_df.empty:
    train_df_head_sample = train_df.head(2)


def _validate_split_frame(split_frame, split_name):
    """Raise ValueError unless ``split_frame`` has source/target columns 0 and 1 with no empty cells.

    Empty CSV fields are read as NaN (see ``na_values`` above). A NaN word would
    otherwise surface much later as a TypeError inside the string-based
    preprocessing (``len(word)`` or ``'<' + word``), far from its cause.
    """
    if split_frame.shape[1] < 2:
        raise ValueError(
            f"{split_name} data must have at least two columns (source, target); "
            f"found {split_frame.shape[1]}."
        )
    missing_rows = int(split_frame[[0, 1]].isna().any(axis=1).sum())
    if missing_rows:
        raise ValueError(
            f"{split_name} data has {missing_rows} row(s) with an empty source or target field."
        )


_validate_split_frame(train_df, "Training")
_validate_split_frame(val_df, "Validation")
_validate_split_frame(test_df, "Test")

# --- Data Extraction ---
# Source (input) and target (output) words as NumPy object arrays.
train_source_list, train_target_list = train_df[0].to_numpy(), train_df[1].to_numpy()
val_source_list, val_target_list = val_df[0].to_numpy(), val_df[1].to_numpy()
test_source_list, test_target_list = test_df[0].to_numpy(), test_df[1].to_numpy()

train_source, train_target = train_source_list, train_target_list
val_source, val_target = val_source_list, val_target_list
test_source, test_target = test_source_list, test_target_list

# Both columns come from the same frame, so this only guards against future edits above.
if len(train_source) != len(train_target):
    raise ValueError("Training source and target data have mismatched lengths.")
data_len_match_status = True
calculated_sum_len = len(train_source) + len(train_target)


# --- Padding Utility ---
def _internal_string_constructor(original_str, max_len_val):
    """Wrap one word in START/END tokens and fit it to ``max_len_val`` characters.

    The word becomes ``START_TOKEN + original_str + END_TOKEN``. That string is
    then truncated to ``max_len_val``, which can cut off the END token and
    trailing characters, or right-padded with PAD_TOKEN. With the
    single-character PAD_TOKEN the result is exactly ``max_len_val`` long.

    Args:
        original_str: The raw word.
        max_len_val: Target length; must be non-negative.

    Returns:
        The framed and fitted string.

    Raises:
        ValueError: If ``max_len_val`` is negative.
    """
    if max_len_val < 0:
        raise ValueError(f"max_len_val must be non-negative, got {max_len_val}")
    framed = (START_TOKEN + original_str + END_TOKEN)[:max_len_val]
    # After truncation the shortfall is >= 0, so this only ever appends padding.
    return framed + PAD_TOKEN * (max_len_val - len(framed))

def add_padding(source_data_list, MAX_LENGTH_PARAM):
    """
    Frame every word with START/END tokens and fit it to MAX_LENGTH_PARAM characters.

    Each item is processed by ``_internal_string_constructor``, which prepends
    START_TOKEN, appends END_TOKEN, then truncates or right-pads with PAD_TOKEN.
    If a result ever has the wrong length, a warning is printed and the length
    is forced as a safety net.

    Args:
        source_data_list: Iterable of strings (list, NumPy object array, pandas Series, ...).
        MAX_LENGTH_PARAM: Target length of every output string; must be non-negative.

    Returns:
        A new list with one fixed-length string per input item, in input order.

    Raises:
        ValueError: If MAX_LENGTH_PARAM is negative.
    """
    if MAX_LENGTH_PARAM < 0:
        raise ValueError(f"MAX_LENGTH_PARAM must be non-negative, got {MAX_LENGTH_PARAM}")

    padded_source_strings_collection = []
    for idx, current_sequence_str in enumerate(source_data_list):
        processed_string = _internal_string_constructor(current_sequence_str, MAX_LENGTH_PARAM)

        if len(processed_string) != MAX_LENGTH_PARAM:
            # Safety net: the helper is expected to return exactly MAX_LENGTH_PARAM characters.
            print(
                f"WARNING: Padding failed for sequence {idx}: expected length "
                f"{MAX_LENGTH_PARAM}, got {len(processed_string)}"
            )
            if len(processed_string) > MAX_LENGTH_PARAM:
                processed_string = processed_string[:MAX_LENGTH_PARAM]
            else:
                processed_string += PAD_TOKEN * (MAX_LENGTH_PARAM - len(processed_string))

        padded_source_strings_collection.append(processed_string)

    return padded_source_strings_collection


# --- Character to Index Mapping Utility ---
def _char_to_int_conversion(input_char, char_to_idx_map):
    """
    Return the vocabulary index of ``input_char``.

    Characters absent from ``char_to_idx_map`` are reported on stdout and mapped
    to the PAD_TOKEN index, or to 2 when PAD_TOKEN itself is not in the map.
    """
    pad_fallback = char_to_idx_map.get(PAD_TOKEN, 2)
    if input_char not in char_to_idx_map:
        # Possibly an incomplete vocabulary (e.g. built from different data).
        print(f"Character '{input_char}' not found in vocabulary. Using PAD index.")
        return pad_fallback
    return char_to_idx_map[input_char]


def get_chars(target_string_to_convert, char_index_map_dict):
    """
    Encode a string as a 1-D LongTensor of vocabulary indices on ``device``.

    Every character is looked up with ``_char_to_int_conversion``. Characters
    missing from ``char_index_map_dict`` are therefore reported and mapped to
    the PAD index.

    Args:
        target_string_to_convert: The string to encode.
        char_index_map_dict: Mapping from character to integer index.

    Returns:
        torch.LongTensor of shape (len(target_string_to_convert),).
    """
    encoded = [
        _char_to_int_conversion(symbol, char_index_map_dict)
        for symbol in target_string_to_convert
    ]
    return torch.tensor(encoded, dtype=torch.long, device=device)


# --- String to Sequence of Indices Utility ---
def generate_string_to_sequence(list_of_padded_strings, char_to_idx_lookup_dict):
    """
    Encode several strings and stack them into one (batch, seq_len) LongTensor.

    Each string is encoded with ``get_chars``. ``pad_sequence`` then right-pads
    any shorter rows with the PAD_TOKEN index (2 if PAD_TOKEN is absent from the
    map). Strings produced by ``add_padding`` already share one length, in which
    case this is a plain stack.

    Args:
        list_of_padded_strings: Iterable of strings.
        char_to_idx_lookup_dict: Character-to-index mapping.

    Returns:
        torch.LongTensor of shape (number_of_strings, longest_string_length).
    """
    row_tensors = [get_chars(text, char_to_idx_lookup_dict) for text in list_of_padded_strings]

    pad_index = char_to_idx_lookup_dict.get(PAD_TOKEN, 2)
    # padding_value is annotated as float in PyTorch; .long() guarantees an int64 result.
    stacked = pad_sequence(
        row_tensors,
        batch_first=True,
        padding_value=float(pad_index),
    ).long()

    if stacked.dim() != 2:
        print(f"Warning: Resulting sequence batch has {stacked.dim()} dimensions, expected 2.")
    return stacked


# --- Comprehensive Data Preprocessing Orchestrator ---
def _extend_vocabulary(strings, chars, char_index, index_char):
    """Add every unseen character of ``strings`` to a vocabulary in first-seen order (mutates all three containers)."""
    for text in strings:
        for ch in text:
            if ch not in char_index:
                new_index = len(chars)
                chars.append(ch)
                char_index[ch] = new_index
                index_char[new_index] = ch


def preprocess_data(raw_source_data_list, raw_target_data_list):
    """
    Build vocabularies and fixed-length index tensors for source/target words.

    Steps:
    1. Start both vocabularies with START, END and PAD (indices 0, 1, 2).
    2. Set INPUT_MAX_LENGTH / OUTPUT_MAX_LENGTH to the longest raw word + 2
       (room for the START and END tokens).
    3. Frame and pad every word to that length with ``add_padding``.
    4. Extend each vocabulary with new characters in first-seen order.
    5. Encode the padded words with ``generate_string_to_sequence``.

    Args:
        raw_source_data_list: Sized iterable of source strings (list, NumPy
            array, pandas Series, ...).
        raw_target_data_list: Sized iterable of target strings. It is processed
            independently of the source list; callers must keep the two aligned.

    Returns:
        A dict with these keys:
        - ``source_chars`` / ``target_chars``: index-ordered character lists.
        - ``source_char_index`` / ``target_char_index``: char -> index.
        - ``source_index_char`` / ``target_index_char``: index -> char.
        - ``source_len`` / ``target_len``: vocabulary sizes.
        - ``source_data_orig`` / ``target_data_orig``: the inputs, unchanged.
        - ``source_data_seq`` / ``target_data_seq``: encoded LongTensors.
        - ``INPUT_MAX_LENGTH`` / ``OUTPUT_MAX_LENGTH``.
    """
    special_tokens = [START_TOKEN, END_TOKEN, PAD_TOKEN]

    data_container = {
        "source_chars": list(special_tokens),
        "target_chars": list(special_tokens),
        "source_char_index": {token: i for i, token in enumerate(special_tokens)},
        "source_index_char": {i: token for i, token in enumerate(special_tokens)},
        "target_char_index": {token: i for i, token in enumerate(special_tokens)},
        "target_index_char": {i: token for i, token in enumerate(special_tokens)},
        "source_len": len(special_tokens),
        "target_len": len(special_tokens),
        "source_data_orig": raw_source_data_list,
        "target_data_orig": raw_target_data_list,
        "source_data_seq": None,
        "target_data_seq": None,
    }

    # Longest raw word + 2 for START/END. max(..., default=0) works for lists,
    # NumPy arrays and Series alike, whereas a bare `if array:` raises ValueError
    # for multi-element NumPy arrays.
    data_container["INPUT_MAX_LENGTH"] = max((len(s) for s in raw_source_data_list), default=0) + 2
    data_container["OUTPUT_MAX_LENGTH"] = max((len(t) for t in raw_target_data_list), default=0) + 2

    padded_source_str_list = add_padding(list(raw_source_data_list), data_container["INPUT_MAX_LENGTH"])
    padded_target_str_list = add_padding(list(raw_target_data_list), data_container["OUTPUT_MAX_LENGTH"])

    # The two vocabularies are independent. Scanning each full list separately
    # yields the same indices as interleaved scanning, but no target string is
    # skipped (or crashes) when the list lengths differ.
    _extend_vocabulary(
        padded_source_str_list,
        data_container["source_chars"],
        data_container["source_char_index"],
        data_container["source_index_char"],
    )
    _extend_vocabulary(
        padded_target_str_list,
        data_container["target_chars"],
        data_container["target_char_index"],
        data_container["target_index_char"],
    )

    data_container["source_len"] = len(data_container["source_chars"])
    data_container["target_len"] = len(data_container["target_chars"])

    data_container["source_data_seq"] = generate_string_to_sequence(
        padded_source_str_list, data_container["source_char_index"]
    )
    data_container["target_data_seq"] = generate_string_to_sequence(
        padded_target_str_list, data_container["target_char_index"]
    )

    return data_container

# --- RNN Cell Type Selector ---
def get_cell_type(cell_type_identifier_str):
    """
    Map a cell-type name to the matching PyTorch recurrent module class.

    Matching is case-insensitive ("rnn", "lstm", "gru"). Any other name prints a
    warning and falls back to ``nn.GRU``.

    Args:
        cell_type_identifier_str: Name of the recurrent cell type.

    Returns:
        One of ``nn.RNN``, ``nn.LSTM`` or ``nn.GRU`` (the class, not an instance).
    """
    registry = {"RNN": nn.RNN, "LSTM": nn.LSTM, "GRU": nn.GRU}
    lookup_key = cell_type_identifier_str.upper()
    if lookup_key not in registry:
        print(f"Warning: Unrecognized cell type '{cell_type_identifier_str}'. Defaulting to GRU.")
        return nn.GRU
    return registry[lookup_key]

# --- Attention Mechanism Module ---
class Attention(nn.Module):
    """
    Additive (Bahdanau-style) attention.

    For each source position i: score_i = v^T tanh(W_a q + U_a k_i). Scores
    are softmax-normalised over source positions and used to average the
    encoder outputs into a single context vector. None of the three
    projections has a bias.
    """

    def __init__(self, hidden_dimension_size):
        super(Attention, self).__init__()

        # Projects the decoder query.
        self.Wa_linear_transform = nn.Linear(hidden_dimension_size, hidden_dimension_size, bias=False)
        # Projects each encoder output (key).
        self.Ua_linear_transform = nn.Linear(hidden_dimension_size, hidden_dimension_size, bias=False)
        # Reduces each combined query/key vector to a scalar score.
        self.Va_scoring_vector = nn.Linear(hidden_dimension_size, 1, bias=False)

        # Informational attributes; not used in the computation.
        self.internal_hidden_dim = hidden_dimension_size
        self.attention_type = "bahdanau_style"

    def forward(self, decoder_hidden_query, encoder_outputs_keys, source_mask=None):
        """
        Compute the attention-weighted context vector.

        Args:
            decoder_hidden_query: Query of shape (batch, hidden) or (batch, 1, hidden).
            encoder_outputs_keys: Encoder outputs of shape (batch, src_len, hidden).
            source_mask: Optional tensor broadcastable to (batch, src_len).
                True/nonzero entries mark positions that may be attended to;
                False/zero entries (e.g. PAD positions) receive zero weight.
                A row with every position masked gets uniform weights rather
                than NaN. ``None`` (the default) attends to every position.

        Returns:
            ``(context_vector, attention_weights)``:
            - context_vector has shape (batch, 1, hidden).
            - attention_weights has shape (batch, src_len) and sums to 1 along dim 1.
        """
        query = (
            decoder_hidden_query.unsqueeze(1)
            if decoder_hidden_query.dim() == 2
            else decoder_hidden_query
        )

        # Broadcast the projected query over every source position:
        # (batch, 1, hidden) + (batch, src_len, hidden) -> (batch, src_len, hidden)
        energy = torch.tanh(
            self.Wa_linear_transform(query) + self.Ua_linear_transform(encoder_outputs_keys)
        )
        scores = self.Va_scoring_vector(energy).squeeze(2)  # (batch, src_len)

        if source_mask is not None:
            # The dtype's most negative finite value (not -inf) gives masked
            # positions zero weight while keeping a fully masked row NaN-free.
            scores = scores.masked_fill(~source_mask.bool(), torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores, dim=1)  # (batch, src_len)

        # (batch, 1, src_len) @ (batch, src_len, hidden) -> (batch, 1, hidden)
        context_vector = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs_keys)
        return context_vector, attention_weights

# --- Encoder Module ---
class Encoder(nn.Module):
    """
    Encoder half of the sequence-to-sequence model.

    Embeds a batch of source character indices and runs them through a
    unidirectional RNN / LSTM / GRU, selected by ``cell_type`` via
    ``get_cell_type``. ``forward`` returns every time step's output (intended
    for attention) together with the final hidden state.

    Keys read from ``hyperparameter_settings``: ``char_embd_dim``,
    ``hidden_layer_neurons``, ``number_of_layers``, ``cell_type``,
    ``batch_size``, and optionally ``dropout``.
    Keys read from ``data_configuration_dict``: ``source_len`` and
    ``source_char_index``.
    """

    def __init__(self, hyperparameter_settings, data_configuration_dict, target_device_obj):
        super(Encoder, self).__init__()

        self.hyper_params_config = hyperparameter_settings
        self.data_config = data_configuration_dict
        self.computation_device = target_device_obj

        # Embedding: character index -> dense vector. padding_idx keeps the PAD
        # row out of gradient updates.
        source_vocab_size = self.data_config["source_len"]
        embedding_dim_size = self.hyper_params_config["char_embd_dim"]
        pad_token_idx_source = self.data_config["source_char_index"].get(PAD_TOKEN, 2)
        self.embedding_layer = nn.Embedding(
            num_embeddings=source_vocab_size,
            embedding_dim=embedding_dim_size,
            padding_idx=pad_token_idx_source,
        )
        # Alias kept for callers and saved state_dicts that use the name `embedding`.
        self.embedding = self.embedding_layer

        rnn_hidden_size = self.hyper_params_config["hidden_layer_neurons"]
        rnn_num_layers = self.hyper_params_config["number_of_layers"]
        # PyTorch applies dropout only between stacked layers, so it is disabled for one layer.
        rnn_dropout_rate = self.hyper_params_config.get("dropout", 0.0) if rnn_num_layers > 1 else 0.0

        self.selected_cell_type_str = self.hyper_params_config["cell_type"]
        self.rnn_unit = get_cell_type(self.selected_cell_type_str)(
            input_size=embedding_dim_size,
            hidden_size=rnn_hidden_size,
            num_layers=rnn_num_layers,
            batch_first=True,
            dropout=rnn_dropout_rate,
            bidirectional=False,
        )
        # Alias kept for callers and saved state_dicts that use the name `cell`.
        self.cell = self.rnn_unit

        self.forward_pass_count = 0  # number of forward() calls so far
        self.is_encoder_initialized = True

    def forward(self, input_sequence_batch, initial_hidden_state_encoder=None):
        """
        Encode a batch of source sequences.

        Args:
            input_sequence_batch: LongTensor of character indices, shape (batch, src_len).
            initial_hidden_state_encoder: Initial RNN state, as produced by
                ``getInitialState``: a tensor of shape (num_layers, batch, hidden),
                or an ``(h_0, c_0)`` tuple of such tensors for LSTM. ``None`` lets
                PyTorch use zeros sized to the actual batch, which also handles a
                smaller final batch.

        Returns:
            ``(all_rnn_outputs, final_rnn_hidden_state)``:
            - all_rnn_outputs has shape (batch, src_len, hidden): the last
              layer's output at every time step.
            - final_rnn_hidden_state has the same structure as the initial state.
        """
        if not self.is_encoder_initialized:
            raise RuntimeError("Encoder not properly initialized.")
        self.forward_pass_count += 1

        # (batch, src_len) -> (batch, src_len, char_embd_dim)
        embedded_sequence = self.embedding_layer(input_sequence_batch)
        all_rnn_outputs, final_rnn_hidden_state = self.rnn_unit(
            embedded_sequence, initial_hidden_state_encoder
        )
        return all_rnn_outputs, final_rnn_hidden_state

    def getInitialState(self, batch_size=None):
        """
        Build an all-zero initial state for the encoder RNN.

        Args:
            batch_size: Batch dimension of the state. Defaults to the configured
                ``batch_size``; pass the actual size for a smaller final batch.

        Returns:
            For LSTM: an ``(h_0, c_0)`` tuple of two zero tensors, each of shape
            (number_of_layers, batch_size, hidden_layer_neurons), on
            ``computation_device``. The cell type is matched
            case-insensitively, exactly like ``get_cell_type``.
            For RNN / GRU: a single zero tensor of that shape.
        """
        if batch_size is None:
            batch_size = self.hyper_params_config["batch_size"]
        state_shape = (
            self.hyper_params_config["number_of_layers"],  # x num_directions, which is 1 (unidirectional)
            batch_size,
            self.hyper_params_config["hidden_layer_neurons"],
        )
        initial_hidden_tensor = torch.zeros(state_shape, device=self.computation_device)
        if str(self.selected_cell_type_str).upper() == "LSTM":
            initial_cell_tensor = torch.zeros(state_shape, device=self.computation_device)
            return (initial_hidden_tensor, initial_cell_tensor)
        return initial_hidden_tensor

# --- Decoder Module with Attention ---
class Decoder(nn.Module):
    """
    Attention-based decoder of the sequence-to-sequence model.

    At every time step the previous target token is embedded (and passed
    through ReLU). The last-layer decoder hidden state is used as the query
    to attend over the encoder outputs. The concatenation
    [embedding ; context] is fed through the recurrent cell, and the cell
    output is projected to log-probabilities over the target vocabulary.

    Args:
        hyperparameter_settings (dict): reads "hidden_layer_neurons",
            "char_embd_dim", "number_of_layers", "cell_type" and, optionally,
            "dropout".
        data_configuration_dict (dict): reads "target_len",
            "target_char_index" and "OUTPUT_MAX_LENGTH".
        target_device_obj (torch.device): device on which new tensors
            (start tokens, loss accumulator) are created.
    """
    def __init__(self, hyperparameter_settings, data_configuration_dict, target_device_obj):
        super(Decoder, self).__init__()

        self.hyper_params_config = hyperparameter_settings
        self.data_config = data_configuration_dict
        self.computation_device = target_device_obj
        self.selected_cell_type_str = self.hyper_params_config["cell_type"]

        hidden_size = self.hyper_params_config["hidden_layer_neurons"]
        num_layers = self.hyper_params_config["number_of_layers"]
        embedding_dim_size = self.hyper_params_config["char_embd_dim"]
        target_vocab_size = self.data_config["target_len"]

        # 1. Attention over the encoder outputs.
        # Sub-modules are built in the same order as before (attention ->
        # embedding -> rnn -> linear) so parameter init consumes the RNG
        # identically and state_dict keys keep their order.
        self.attention_mechanism_layer = Attention(hidden_size).to(self.computation_device)
        self.attention = self.attention_mechanism_layer  # legacy alias (same module object)

        # 2. Target-side character embedding; the PAD row receives no gradient.
        pad_token_idx_target = self.data_config["target_char_index"].get(PAD_TOKEN, 2)
        self.embedding_layer = nn.Embedding(
            num_embeddings=target_vocab_size,
            embedding_dim=embedding_dim_size,
            padding_idx=pad_token_idx_target,
        )
        self.embedding = self.embedding_layer  # legacy alias

        # 3. Recurrent cell. Per-step input is [embedding ; attention context].
        # Dropout is only meaningful between stacked layers, so it is zeroed
        # for a single-layer cell (assumes get_cell_type returns a torch.nn
        # RNN/GRU/LSTM class — TODO confirm).
        rnn_dropout = self.hyper_params_config.get("dropout", 0.0) if num_layers > 1 else 0.0
        self.rnn_unit_decoder = get_cell_type(self.selected_cell_type_str)(
            input_size=embedding_dim_size + hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,  # tensors are (batch, seq, feature)
            dropout=rnn_dropout,
        )
        self.cell = self.rnn_unit_decoder  # legacy alias

        # 4. Projection of the cell output onto target-vocabulary scores.
        self.output_projection_layer = nn.Linear(hidden_size, target_vocab_size)
        self.fc = self.output_projection_layer  # legacy alias

        # 5. LogSoftmax over the vocabulary axis (pairs with NLLLoss).
        self.log_softmax_activation = nn.LogSoftmax(dim=1)
        self.softmax = self.log_softmax_activation  # legacy alias (it is LogSoftmax)

        # Bookkeeping attributes kept for backward compatibility.
        self.decoder_total_steps_processed = 0
        self.decoder_is_ready = True


    def _internal_decoder_step(self, current_token_indices_input, prev_decoder_hidden_cell_state, all_encoder_output_states):
        """
        Run a single decoding step.

        Args:
            current_token_indices_input (torch.LongTensor): (batch, 1) token ids.
            prev_decoder_hidden_cell_state: previous recurrent state with layers
                on dim 0. It is an (h, c) tuple for LSTM, a single tensor otherwise.
            all_encoder_output_states (torch.Tensor): encoder outputs used as
                attention keys/values (presumably (batch, src_len, hidden) —
                TODO confirm against Attention).

        Returns:
            tuple: (log_probs of shape (batch, vocab), new recurrent state,
            attention weights as returned by the attention layer).

        Raises:
            SystemError: if ``decoder_is_ready`` has been cleared.
        """
        if not self.decoder_is_ready:
            raise SystemError("Decoder logic error: not ready.")
        self.decoder_total_steps_processed += 1

        # ReLU on the embedding is part of the original model design.
        embedded_tokens = F.relu(self.embedding_layer(current_token_indices_input))

        # Attention query: the last layer's hidden state (h of (h, c) for LSTM).
        if self.selected_cell_type_str == "LSTM":
            query_for_attention_step = prev_decoder_hidden_cell_state[0][-1]
        else:
            query_for_attention_step = prev_decoder_hidden_cell_state[-1]

        context_vector_step, attention_weights_step = self.attention_mechanism_layer(
            query_for_attention_step, all_encoder_output_states
        )

        # The context must be (batch, 1, hidden) to concatenate on the feature axis.
        rnn_input_concat = torch.cat((embedded_tokens, context_vector_step), dim=2)
        rnn_output_step, new_decoder_hidden_cell_state = self.rnn_unit_decoder(
            rnn_input_concat, prev_decoder_hidden_cell_state
        )

        # Drop the length-1 time axis, then score and normalize.
        output_scores_step = self.output_projection_layer(rnn_output_step.squeeze(1))
        output_log_probs_step = self.log_softmax_activation(output_scores_step)

        return output_log_probs_step, new_decoder_hidden_cell_state, attention_weights_step


    def forward(self, initial_decoder_internal_state, all_encoder_outputs_collection, target_sequences_batch_gt, loss_function_obj, teacher_forcing_is_enabled=True):
        """
        Decode a whole batch for ``OUTPUT_MAX_LENGTH`` steps with greedy choice.

        Args:
            initial_decoder_internal_state: initial recurrent state for the
                decoder (typically the encoder's final state).
            all_encoder_outputs_collection (torch.Tensor): encoder outputs. The
                batch size is read from dim 0.
            target_sequences_batch_gt (torch.Tensor or None): (batch, out_len)
                ground-truth ids. If None, no loss or accuracy is computed and
                teacher forcing is disabled.
            loss_function_obj: callable(log_probs, targets) -> scalar tensor,
                applied per step and summed.
            teacher_forcing_is_enabled (bool): when True (and targets are
                given), the whole pass uses teacher forcing with probability
                ``TEACHER_FORCING_RATIO`` (clamped to [0, 1]).

        Returns:
            tuple:
                - predicted_indices_matrix: (batch, max_len) predicted ids.
                - all_attention_weights_tensor: per-step attention weights
                  stacked on dim 1.
                - total_batch_loss: sum of per-step losses (0.0 tensor when
                  there are no targets).
                - num_correct_sequences_in_batch: int count of sequences whose
                  every position matches the targets (0 when there are no
                  targets).
        """
        # Use the actual batch size rather than the configured one, so a
        # smaller final batch (e.g. a DataLoader with drop_last=False) stays
        # shape-consistent with the encoder outputs and the attention context.
        current_batch_size_val = all_encoder_outputs_collection.size(0)
        max_output_seq_len = self.data_config["OUTPUT_MAX_LENGTH"]
        target_char_index = self.data_config["target_char_index"]

        # Every sequence starts from the START token; shape (batch, 1).
        current_step_decoder_input = torch.full(
            (current_batch_size_val, 1),
            fill_value=target_char_index[START_TOKEN],
            device=self.computation_device,
            dtype=torch.long,  # embedding lookup requires integer ids
        )
        current_step_decoder_hidden_state = initial_decoder_internal_state

        list_of_predicted_indices_per_step = []
        list_of_attention_weights_per_step = []
        accumulated_batch_loss = torch.tensor(0.0, device=self.computation_device)
        count_of_exact_matches_in_batch = 0

        has_targets = target_sequences_batch_gt is not None

        # Teacher forcing is decided once per forward pass (whole sequence).
        apply_teacher_forcing_this_pass = False
        if teacher_forcing_is_enabled and has_targets:
            current_tf_ratio = TEACHER_FORCING_RATIO
            if not (0.0 <= current_tf_ratio <= 1.0):
                print(f"Warning: TEACHER_FORCING_RATIO ({TEACHER_FORCING_RATIO}) is outside [0,1]. Clamping.")
                current_tf_ratio = max(0.0, min(1.0, current_tf_ratio))
            apply_teacher_forcing_this_pass = random.random() < current_tf_ratio

        for time_step_idx in range(max_output_seq_len):
            output_log_probs_current_step, current_step_decoder_hidden_state, attention_weights_current_step = (
                self._internal_decoder_step(
                    current_step_decoder_input,
                    current_step_decoder_hidden_state,
                    all_encoder_outputs_collection,
                )
            )

            # Greedy decoding: index of the highest log-probability, (batch, 1).
            _, top_predicted_index_tensor = output_log_probs_current_step.topk(k=1, dim=1)

            list_of_predicted_indices_per_step.append(top_predicted_index_tensor.squeeze(1).clone().detach())
            list_of_attention_weights_per_step.append(attention_weights_current_step.clone().detach())

            if has_targets:
                target_tokens_this_step = target_sequences_batch_gt[:, time_step_idx]
                accumulated_batch_loss = accumulated_batch_loss + loss_function_obj(
                    output_log_probs_current_step, target_tokens_this_step
                )

            # Next input: ground truth under teacher forcing, otherwise the
            # model's own (detached) prediction.
            if apply_teacher_forcing_this_pass:
                current_step_decoder_input = target_tokens_this_step.unsqueeze(1)
            else:
                current_step_decoder_input = top_predicted_index_tensor.detach()

        predicted_indices_matrix = torch.stack(list_of_predicted_indices_per_step, dim=1)
        all_attention_weights_tensor = torch.stack(list_of_attention_weights_per_step, dim=1)

        if has_targets:
            # Strict metric: every position (padding included) must match.
            matches_bool_tensor = predicted_indices_matrix == target_sequences_batch_gt
            count_of_exact_matches_in_batch = matches_bool_tensor.all(dim=1).sum().item()

        return predicted_indices_matrix, all_attention_weights_tensor, accumulated_batch_loss, count_of_exact_matches_in_batch

# --- Custom PyTorch Dataset Class ---
class MyDataset(tud.Dataset):
    """
    Map-style dataset that pairs source and target sequences by position.

    Args:
        tuple_of_source_target_data: indexable whose item 0 holds the source
            sequences and item 1 the target sequences. Both must be the same
            length.

    Raises:
        ValueError: if the two collections differ in length (the message is
            also printed before raising).
    """
    def __init__(self, tuple_of_source_target_data):
        self.source_data_sequences_tensor = tuple_of_source_target_data[0]
        self.target_data_sequences_tensor = tuple_of_source_target_data[1]

        n_sources = len(self.source_data_sequences_tensor)
        if n_sources != len(self.target_data_sequences_tensor):
            mismatch = ValueError("Source and target data must contain the same number of sequences.")
            print(f"CRITICAL ERROR in MyDataset: {mismatch}")
            raise mismatch

        self.total_samples_in_dataset = n_sources
        # Kept as a public attribute; __len__ consults it.
        self.is_initialized_properly = True

    def __len__(self):
        """Number of (source, target) pairs, or 0 if construction did not complete."""
        return self.total_samples_in_dataset if self.is_initialized_properly else 0

    def __getitem__(self, sample_index):
        """
        Return the ``(source, target)`` pair stored at ``sample_index``.

        Negative indices are rejected rather than wrapped around.

        Raises:
            IndexError: if ``sample_index`` is outside ``[0, len(self))``.
        """
        size = self.total_samples_in_dataset
        if sample_index < 0 or sample_index >= size:
            raise IndexError(f"Sample index {sample_index} is out of range for dataset size {size}.")
        return (
            self.source_data_sequences_tensor[sample_index],
            self.target_data_sequences_tensor[sample_index],
        )


# --- Index-to-String Conversion Utility ---
def _idx_to_char_safe(idx_val, idx_to_char_map_dict, default_char='?'):
    """Internal helper for safe index-to-char conversion."""
    actual_char = idx_to_char_map_dict.get(idx_val.item(), default_char)
    # Redundant check for item() output type
    if not isinstance(idx_val.item(), int):
        print(f"Warning: item from tensor {idx_val} is not int.")
    return actual_char

def make_strings(data_config_dict_obj, source_indices_seq, target_indices_seq, output_indices_seq):
    """
    Convert index sequences back into human-readable strings.

    The source uses ``data_config_dict_obj['source_index_char']``. The target
    and model output share ``data_config_dict_obj['target_index_char']``.
    Unknown indices become the fallback character used by
    ``_idx_to_char_safe``.

    Args:
        data_config_dict_obj (dict): data configuration holding the
            index-to-character maps.
        source_indices_seq: iterable of source index values.
        target_indices_seq: iterable of ground-truth target index values.
        output_indices_seq: iterable of predicted index values.

    Returns:
        tuple[str, str, str]: (source_string, target_string, output_string),
        each decoded in its original order and case.
    """
    source_idx_to_char_map = data_config_dict_obj['source_index_char']
    target_idx_to_char_map = data_config_dict_obj['target_index_char']

    def _decode(indices, idx_to_char_map):
        # Decode one sequence without reordering or case-folding it.
        return "".join(_idx_to_char_safe(idx, idx_to_char_map) for idx in indices)

    return (
        _decode(source_indices_seq, source_idx_to_char_map),
        _decode(target_indices_seq, target_idx_to_char_map),
        _decode(output_indices_seq, target_idx_to_char_map),
    )

# --- Default Hyperparameters Definition ---
# Fallback configuration used when a run (e.g. a sweep) does not supply a value.
default_hyperparameter_values = dict(
    char_embd_dim=128,          # size of each character embedding vector
    hidden_layer_neurons=256,   # hidden units per recurrent layer
    batch_size=64,              # sequences per mini-batch
    number_of_layers=2,         # stacked recurrent layers
    learning_rate=0.0005,       # optimizer step size
    epochs=15,                  # passes over the training data
    cell_type="GRU",            # name passed to get_cell_type
    dropout=0.25,               # inter-layer dropout (Decoder zeroes it for a single layer)
    optimizer="adamw",          # name resolved by _select_optimizer
)
# Legacy alias: both names refer to the SAME dict object.
h_params = default_hyperparameter_values

# Module-level switch consulted by prepare_dataloaders.
SOME_GLOBAL_FLAG_EXAMPLE = True

# --- DataLoader Preparation Function ---
def _build_eval_dataloader(source_data, target_data, data_config, batch_size, drop_last):
    """Pad, index and wrap an evaluation split using the TRAINING vocab and max lengths."""
    padded_source_strings = add_padding(list(source_data), data_config["INPUT_MAX_LENGTH"])
    padded_target_strings = add_padding(list(target_data), data_config["OUTPUT_MAX_LENGTH"])

    source_indexed_sequences = generate_string_to_sequence(padded_source_strings, data_config['source_char_index'])
    target_indexed_sequences = generate_string_to_sequence(padded_target_strings, data_config['target_char_index'])

    return tud.DataLoader(
        MyDataset([source_indexed_sequences, target_indexed_sequences]),
        batch_size=batch_size,
        shuffle=False,  # evaluation order is kept deterministic
        drop_last=drop_last,
    )


def prepare_dataloaders(source_train_data, target_train_data,
                        source_val_data, target_val_data,
                        source_test_data, target_test_data,
                        current_run_h_params):
    """
    Build the train/validation/test DataLoaders.

    The vocabularies and maximum lengths are derived from the training data
    (via ``preprocess_data``). Validation and test data are padded and
    indexed with those same training-derived settings.

    Args:
        source_train_data, target_train_data: training source/target strings.
        source_val_data, target_val_data: validation source/target strings.
        source_test_data, target_test_data: test source/target strings.
        current_run_h_params (dict): run hyperparameters; reads "batch_size".

    Returns:
        tuple: (train_dataloader, val_dataloader, test_dataloader,
        data_configuration_object), or four ``None`` values when
        ``SOME_GLOBAL_FLAG_EXAMPLE`` is falsy.

    Notes:
        The train and validation loaders use ``drop_last=True``, so up to
        ``batch_size - 1`` trailing samples of each are skipped. A validation
        set smaller than ``batch_size`` therefore yields zero batches.
        Presumably this is because the encoder sizes its initial hidden state
        from the configured batch_size (TODO confirm). The test loader keeps
        every sample (``drop_last=False``).
    """
    if not SOME_GLOBAL_FLAG_EXAMPLE:
        print("This should not be printed due to SOME_GLOBAL_FLAG_EXAMPLE.")
        return None, None, None, None

    batch_size = current_run_h_params["batch_size"]

    # 1. Training data defines vocabularies and max lengths. Copies protect
    #    the caller's arrays from any in-place changes in preprocess_data.
    data_processing_config_obj = preprocess_data(
        copy.deepcopy(list(source_train_data)),
        copy.deepcopy(list(target_train_data))
    )

    # 2. Training loader: shuffled, fixed-size batches.
    train_dataset = MyDataset([
        data_processing_config_obj["source_data_seq"],
        data_processing_config_obj['target_data_seq'],
    ])
    train_dataloader_obj = tud.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=0,
    )

    # 3./4. Evaluation loaders share one preprocessing path.
    val_dataloader_obj = _build_eval_dataloader(
        source_val_data, target_val_data, data_processing_config_obj, batch_size, drop_last=True
    )
    test_dataloader_obj = _build_eval_dataloader(
        source_test_data, target_test_data, data_processing_config_obj, batch_size, drop_last=False
    )

    # No truthiness check here: bool(DataLoader) calls __len__, and a
    # legitimately empty loader (e.g. tiny val set with drop_last=True)
    # would be misreported as a failed initialization.
    return train_dataloader_obj, val_dataloader_obj, test_dataloader_obj, data_processing_config_obj

# --- Optimizer Selection Helper ---
def _select_optimizer(model_params_iterable, opt_name_str, learning_rate_float):
    """Internal helper to create an optimizer instance."""
    opt_name_lower = opt_name_str.lower() # Case-insensitive

    if opt_name_lower == "adam":
        optimizer_instance = optim.Adam(model_params_iterable, lr=learning_rate_float)
    elif opt_name_lower == "adamw": # AdamW is a good choice, often preferred over Adam.
        optimizer_instance = optim.AdamW(model_params_iterable, lr=learning_rate_float)
    elif opt_name_lower == "nadam": # PyTorch does not have NAdam built-in.
        print("Warning: NAdam optimizer requested. PyTorch lacks a standard NAdam. Using AdamW instead.")
        optimizer_instance = optim.AdamW(model_params_iterable, lr=learning_rate_float)
    elif opt_name_lower == "sgd":
        optimizer_instance = optim.SGD(model_params_iterable, lr=learning_rate_float, momentum=0.9) # SGD with momentum
    else:
        print(f"Warning: Unknown optimizer '{opt_name_str}'. Defaulting to AdamW.")
        optimizer_instance = optim.AdamW(model_params_iterable, lr=learning_rate_float)

    # Redundant assignment for clarity or future extension
    chosen_optimizer = optimizer_instance
    return chosen_optimizer

# --- Training Loop Function ---
def train_loop(current_encoder_model, current_decoder_model,
               active_h_params_dict, comprehensive_data_config,
               training_batches_loader, active_device_obj,
               validation_batches_loader, use_teacher_forcing_in_train=True):
    """
    Run the epoch-based training loop for the encoder-decoder model.

    Each batch goes through these steps:
      1. Zero the gradients.
      2. Encode the source batch.
      3. Decode. The decoder returns the summed NLL loss and a count of
         correctly predicted sequences.
      4. Backpropagate.
      5. Clip each model's gradient norm to 1.0.
      6. Step both optimizers.

    After each epoch:
      - The models are scored on the validation loader with `evaluate`.
      - A summary is printed.
      - Metrics are logged to wandb when a run is active.

    Per-token losses divide the summed NLL by the number of non-PAD target
    tokens. This is the same convention `evaluate` uses, so train and
    validation losses are directly comparable.

    Args:
        current_encoder_model, current_decoder_model: The encoder and decoder model
            instances; updated in place by the optimizers.
        active_h_params_dict (dict): Hyperparameters; reads "optimizer", "learning_rate", "epochs".
        comprehensive_data_config (dict): Data configuration; reads "target_char_index".
        training_batches_loader (DataLoader): Yields (source, target) index batches for training.
        active_device_obj (torch.device): Device for computation (CPU/GPU).
        validation_batches_loader (DataLoader): DataLoader for validation data.
        use_teacher_forcing_in_train (bool): Passed to the decoder as `teacher_forcing_is_enabled`.

    Returns:
        nn.NLLLoss: The loss function instance used (sum reduction, PAD index ignored).
    """
    # Max gradient norm, applied to each model separately after backward().
    clip_value = 1.0

    def _fit_hidden_to_batch(hidden_state, batch_size):
        """Trim an initial RNN state to `batch_size` rows along dim 1.

        Accepts a single tensor or a tuple of tensors (LSTM's (h, c)).
        Only 3-D tensors that are larger than `batch_size` along dim 1 are
        sliced, so a smaller final batch still matches. Anything else is
        returned unchanged.
        """
        if isinstance(hidden_state, tuple):
            return tuple(_fit_hidden_to_batch(part, batch_size) for part in hidden_state)
        if (isinstance(hidden_state, torch.Tensor) and hidden_state.dim() == 3
                and hidden_state.size(1) > batch_size):
            return hidden_state[:, :batch_size, :]
        return hidden_state

    # 1. One optimizer per model, both configured from the same hyperparameters.
    encoder_opt = _select_optimizer(
        current_encoder_model.parameters(),
        active_h_params_dict["optimizer"],
        active_h_params_dict["learning_rate"]
    )
    decoder_opt = _select_optimizer(
        current_decoder_model.parameters(),
        active_h_params_dict["optimizer"],
        active_h_params_dict["learning_rate"]
    )
    optimizers_list = [encoder_opt, decoder_opt]

    # 2. Loss: NLLLoss pairs with a LogSoftmax decoder output. `ignore_index` keeps
    #    PAD positions out of the loss. 'sum' lets us average per real token below.
    pad_token_numerical_idx = comprehensive_data_config["target_char_index"].get(PAD_TOKEN, 2)
    primary_loss_criterion = nn.NLLLoss(ignore_index=pad_token_numerical_idx, reduction='sum')

    # 3. Epoch loop.
    total_num_training_samples = len(training_batches_loader.dataset)
    num_batches_in_epoch = len(training_batches_loader)
    num_epochs = active_h_params_dict["epochs"]
    # Log the batch loss about twice per epoch. max() keeps the modulus non-zero.
    log_interval = max(1, num_batches_in_epoch // 2)

    for epoch_iterator_val in range(num_epochs):
        epoch_total_accumulated_loss = 0.0
        epoch_total_correctly_predicted_sequences = 0
        epoch_total_non_pad_tokens = 0

        # Enable training-mode behavior (e.g. dropout).
        current_encoder_model.train()
        current_decoder_model.train()

        for batch_iteration_idx, (batch_source_seqs, batch_target_seqs) in enumerate(training_batches_loader):
            batch_source_seqs = batch_source_seqs.to(active_device_obj)
            batch_target_seqs = batch_target_seqs.to(active_device_obj)

            # Gradients accumulate across backward() calls unless cleared.
            for opt_instance in optimizers_list:
                opt_instance.zero_grad()

            # --- Encoder ---
            # Trim the initial state in case this (e.g. final) batch is smaller.
            encoder_initial_hidden = _fit_hidden_to_batch(
                current_encoder_model.getInitialState(), batch_source_seqs.size(0)
            )
            all_enc_outputs, final_enc_hidden = current_encoder_model(batch_source_seqs, encoder_initial_hidden)

            # --- Decoder ---
            # Starts from the encoder's final hidden state and attends over all encoder outputs.
            # NOTE(review): with teacher forcing on, `correct_seqs_in_batch` presumably reflects
            # teacher-forced predictions, so train accuracy may be optimistic; confirm in Decoder.
            preds_indices, attn_weights, loss_for_batch, correct_seqs_in_batch = current_decoder_model(
                final_enc_hidden,
                all_enc_outputs,
                batch_target_seqs,  # Ground truth for loss and teacher forcing.
                primary_loss_criterion,
                teacher_forcing_is_enabled=use_teacher_forcing_in_train
            )

            # Accumulate epoch statistics. The token count uses the same non-PAD rule as `evaluate`.
            batch_loss_value = loss_for_batch.item()
            num_tokens_in_batch = batch_target_seqs.ne(pad_token_numerical_idx).sum().item()
            epoch_total_accumulated_loss += batch_loss_value
            epoch_total_non_pad_tokens += num_tokens_in_batch
            epoch_total_correctly_predicted_sequences += correct_seqs_in_batch

            # --- Backpropagation and optimization ---
            loss_for_batch.backward()
            # Clipping guards against exploding gradients, common in RNNs.
            torch.nn.utils.clip_grad_norm_(current_encoder_model.parameters(), max_norm=clip_value)
            torch.nn.utils.clip_grad_norm_(current_decoder_model.parameters(), max_norm=clip_value)
            for opt_instance in optimizers_list:
                opt_instance.step()

            if batch_iteration_idx > 0 and batch_iteration_idx % log_interval == 0:
                avg_batch_loss_per_token = (batch_loss_value / num_tokens_in_batch) if num_tokens_in_batch > 0 else 0
                print(f"Epoch [{epoch_iterator_val+1}/{active_h_params_dict['epochs']}], "
                      f"Batch [{batch_iteration_idx+1}/{num_batches_in_epoch}], "
                      f"Avg Batch Loss/Token: {avg_batch_loss_per_token:.4f}")

        # --- End of epoch: exact per-token loss and sequence accuracy ---
        avg_epoch_train_loss_per_token = (
            epoch_total_accumulated_loss / epoch_total_non_pad_tokens
            if epoch_total_non_pad_tokens > 0 else 0.0
        )
        epoch_train_sequence_accuracy = (
            epoch_total_correctly_predicted_sequences / total_num_training_samples
            if total_num_training_samples > 0 else 0.0
        )

        avg_epoch_val_loss_per_token, epoch_val_sequence_accuracy = evaluate(
            current_encoder_model, current_decoder_model,
            comprehensive_data_config, validation_batches_loader,
            active_device_obj, active_h_params_dict, primary_loss_criterion,
            is_test_evaluation_run=False
        )

        print(f"--- Epoch {epoch_iterator_val+1} Completed ---")
        print(f"  Training: Avg Loss/Token ~ {avg_epoch_train_loss_per_token:.4f}, Seq. Accuracy: {epoch_train_sequence_accuracy:.4f}")
        print(f"  Validation: Avg Loss/Token ~ {avg_epoch_val_loss_per_token:.4f}, Seq. Accuracy: {epoch_val_sequence_accuracy:.4f}")

        # Log only when a wandb run is active.
        if wandb.run is not None:
            wandb.log({
                "epoch_num": epoch_iterator_val + 1,
                "train_loss_epoch_avg_token": avg_epoch_train_loss_per_token,
                "train_accuracy_epoch_seq": epoch_train_sequence_accuracy,
                "val_loss_epoch_avg_token": avg_epoch_val_loss_per_token,
                "val_accuracy_epoch_seq": epoch_val_sequence_accuracy
            })

    print("Training loop concluded.")

    return primary_loss_criterion

# --- Main Training Orchestration Function ---
def train(hyperparameters_for_run, data_config_obj_main, device_obj_main,
          train_data_loader_main, val_data_loader_main, enable_tf_in_training_main=True):
    """
    Build a fresh Encoder/Decoder pair on the target device and train it.

    Both models are built from the same hyperparameters and data config, then
    moved to `device_obj_main`. Their trainable-parameter counts are printed.
    Training is delegated to `train_loop`.

    Args:
        hyperparameters_for_run (dict): Hyperparameters for this training run.
        data_config_obj_main (dict): Data configuration (vocabularies, lengths, ...).
        device_obj_main (torch.device): Target computation device.
        train_data_loader_main (DataLoader): Training batches.
        val_data_loader_main (DataLoader): Validation batches.
        enable_tf_in_training_main (bool): Whether teacher forcing is used while training.

    Returns:
        tuple: (encoder, decoder, loss_fn). The models are trained in place;
        loss_fn is the loss instance returned by `train_loop`.
    """
    def _trainable_count(module):
        # Only parameters that receive gradients are counted.
        return sum(param.numel() for param in module.parameters() if param.requires_grad)

    encoder = Encoder(hyperparameters_for_run, data_config_obj_main, device_obj_main).to(device_obj_main)
    decoder = Decoder(hyperparameters_for_run, data_config_obj_main, device_obj_main).to(device_obj_main)

    encoder_param_total = _trainable_count(encoder)
    decoder_param_total = _trainable_count(decoder)
    print(f"Initialized Encoder with {encoder_param_total:,} trainable parameters.")
    print(f"Initialized Decoder with {decoder_param_total:,} trainable parameters.")
    if encoder_param_total + decoder_param_total == 0:
        print("Warning: Models have no parameters!")

    loss_fn = train_loop(
        encoder, decoder,
        hyperparameters_for_run, data_config_obj_main,
        train_data_loader_main, device_obj_main, val_data_loader_main,
        use_teacher_forcing_in_train=enable_tf_in_training_main,
    )
    return encoder, decoder, loss_fn

# --- Evaluation Function ---
def evaluate(encoder_model_to_eval, decoder_model_to_eval,
             data_meta_config_eval, data_loader_for_evaluation,
             current_device_eval, h_params_config_eval,
             loss_criterion_eval, is_test_evaluation_run=False):
    """
    Score the encoder-decoder on a validation or test DataLoader.

    The function works as follows:
      - Both models are put in eval mode.
      - Gradients are disabled.
      - The decoder always runs with teacher forcing OFF.
    The initial encoder state is trimmed to the current batch size, so a
    smaller final batch (e.g. a loader with drop_last=False) is handled.

    Args:
        encoder_model_to_eval, decoder_model_to_eval: The trained model instances.
        data_meta_config_eval (dict): Data metadata; reads "target_char_index".
        data_loader_for_evaluation (DataLoader): Yields (source, target) index batches.
        current_device_eval (torch.device): Computation device (CPU/GPU).
        h_params_config_eval (dict): Hyperparameters; accepted for interface
            compatibility, not read here.
        loss_criterion_eval: The loss function (same as used in training).
        is_test_evaluation_run (bool): Only selects the phase label printed
            ("TESTING" vs "VALIDATION").

    Returns:
        tuple: (average_loss_per_token, sequence_accuracy).
            - average_loss_per_token: the summed batch losses divided by the
              number of non-PAD target tokens.
            - sequence_accuracy: correct sequences divided by the dataset size.
            Each is 0.0 when its denominator is zero.
    """
    def _fit_hidden_to_batch(hidden_state, batch_size):
        """Trim an initial RNN state to `batch_size` rows along dim 1.

        Accepts a single tensor or a tuple of tensors (LSTM's (h, c)).
        Only 3-D tensors that are larger than `batch_size` along dim 1 are
        sliced. Anything else is returned unchanged.
        """
        if isinstance(hidden_state, tuple):
            return tuple(_fit_hidden_to_batch(part, batch_size) for part in hidden_state)
        if (isinstance(hidden_state, torch.Tensor) and hidden_state.dim() == 3
                and hidden_state.size(1) > batch_size):
            return hidden_state[:, :batch_size, :]
        return hidden_state

    # Eval mode disables dropout and similar training-only behavior.
    encoder_model_to_eval.eval()
    decoder_model_to_eval.eval()

    # The PAD index does not change between batches, so look it up once.
    pad_idx_eval = data_meta_config_eval["target_char_index"].get(PAD_TOKEN, 2)

    total_accumulated_eval_loss = 0.0
    total_correctly_predicted_eval_sequences = 0
    total_non_pad_tokens_evaluated = 0
    num_samples_in_eval_dataset = len(data_loader_for_evaluation.dataset)

    evaluation_phase_identifier = "TESTING" if is_test_evaluation_run else "VALIDATION"
    print(f"--- Starting Evaluation Phase: {evaluation_phase_identifier} ---")

    with torch.no_grad():
        for source_batch_eval, target_batch_eval in data_loader_for_evaluation:
            source_batch_eval = source_batch_eval.to(current_device_eval)
            target_batch_eval = target_batch_eval.to(current_device_eval)

            # --- Encoder ---
            encoder_initial_hidden_eval = _fit_hidden_to_batch(
                encoder_model_to_eval.getInitialState(), source_batch_eval.size(0)
            )
            all_enc_outputs_eval, final_enc_hidden_eval = encoder_model_to_eval(
                source_batch_eval, encoder_initial_hidden_eval
            )

            # --- Decoder (teacher forcing is always off during evaluation) ---
            _, _, loss_for_batch_eval, correct_seqs_in_batch_eval = decoder_model_to_eval(
                final_enc_hidden_eval,
                all_enc_outputs_eval,
                target_batch_eval,  # Ground truth is still needed for the loss.
                loss_criterion_eval,
                teacher_forcing_is_enabled=False
            )

            total_accumulated_eval_loss += loss_for_batch_eval.item()
            total_correctly_predicted_eval_sequences += correct_seqs_in_batch_eval
            # Non-PAD target tokens are the denominator for the per-token loss.
            total_non_pad_tokens_evaluated += target_batch_eval.ne(pad_idx_eval).sum().item()

    avg_loss_per_token_eval = (
        total_accumulated_eval_loss / total_non_pad_tokens_evaluated
        if total_non_pad_tokens_evaluated > 0 else 0.0
    )
    sequence_accuracy_eval = (
        total_correctly_predicted_eval_sequences / num_samples_in_eval_dataset
        if num_samples_in_eval_dataset > 0 else 0.0
    )

    print(f"--- Evaluation Phase {evaluation_phase_identifier} Concluded ---")
    return avg_loss_per_token_eval, sequence_accuracy_eval


# --- String Cleaning Utility ---
def remove_padding(string_with_special_tokens):
    """
    Strip the START, END and PAD marker characters from a string.

    Every character equal to START_TOKEN, END_TOKEN or PAD_TOKEN is dropped.
    All other characters are kept in their original order.

    Args:
        string_with_special_tokens (str): Text that may contain the special tokens.

    Returns:
        str: The text with every special-token character removed.
    """
    special_markers = frozenset((START_TOKEN, END_TOKEN, PAD_TOKEN))
    kept_chars = (ch for ch in string_with_special_tokens if ch not in special_markers)
    return "".join(kept_chars)

# --- Attention Heatmap Plotting Utility ---
# Default font file used to render Telugu glyphs on heatmap tick labels.
# Callers can override it per plot via `custom_font_path_for_telugu`.
TELUGU_FONT_PATH_GLOBAL = (
    '/kaggle/input/fonts-bro-1/'
    'NotoSansTelugu-VariableFont_wdth,wght.ttf'
)

def _prepare_attention_display_labels(token_sequence, special_token_to_stop_at):
    """Return the prefix of `token_sequence` worth labelling on a heatmap axis.

    The processing happens in two steps:
      1. The sequence is cut just after the first `special_token_to_stop_at`.
         The stop token itself stays visible.
      2. The result is cut just before the first PAD_TOKEN that remains.
    If neither token occurs, the original sequence object is returned.
    """
    labels = token_sequence
    if special_token_to_stop_at in labels:
        labels = labels[:labels.index(special_token_to_stop_at) + 1]
    # Padding that appears before (or without) the stop token is hidden as well.
    if PAD_TOKEN in labels:
        labels = labels[:labels.index(PAD_TOKEN)]
    return labels


def plot_attention_heatmap(attention_matrix_data, source_input_token_list, predicted_output_token_list,
                           unique_plot_identifier, custom_font_path_for_telugu=None):
    """
    Draw an attention heatmap and log it to wandb, or save it as a PNG.

    Rows are predicted output tokens and columns are source tokens. Both label
    lists are truncated at END_TOKEN (inclusive) and at PAD_TOKEN. The matrix
    is trimmed to match the truncated labels. An empty trimmed matrix is
    skipped with a message. The figure is always closed.

    Args:
        attention_matrix_data (torch.Tensor or np.ndarray): Attention weights indexed
            as [output_position, input_position].
        source_input_token_list (list): Source tokens (strings).
        predicted_output_token_list (list): Predicted output tokens (strings).
        unique_plot_identifier (str or int): ID used in the title, the wandb key and the file name.
        custom_font_path_for_telugu (str, optional): Font file used for tick labels.
            Defaults to TELUGU_FONT_PATH_GLOBAL.
    """
    effective_font_path = (
        custom_font_path_for_telugu if custom_font_path_for_telugu is not None else TELUGU_FONT_PATH_GLOBAL
    )

    # Seaborn needs a CPU NumPy array. detach() covers tensors that still track gradients.
    if isinstance(attention_matrix_data, torch.Tensor):
        attention_matrix_np = attention_matrix_data.detach().cpu().numpy()
    else:
        attention_matrix_np = attention_matrix_data

    displayable_input_labels = _prepare_attention_display_labels(source_input_token_list, END_TOKEN)
    displayable_output_labels = _prepare_attention_display_labels(predicted_output_token_list, END_TOKEN)
    trimmed_attention_matrix = attention_matrix_np[:len(displayable_output_labels), :len(displayable_input_labels)]

    # Check before creating a figure, so the skip path does not leak an open figure.
    if trimmed_attention_matrix.size == 0:
        print(f"Skipping heatmap for {unique_plot_identifier} due to empty trimmed matrix or data.")
        return

    fig_handle, ax_handle = plt.subplots(figsize=(18, 14))
    try:
        sns.heatmap(
            trimmed_attention_matrix,
            xticklabels=displayable_input_labels,
            yticklabels=displayable_output_labels,
            cmap='viridis',
            annot=False,  # Per-cell numbers would be unreadable at character level.
            linewidths=0.2,
            linecolor='gray',
            cbar=True,
            square=False,
            ax=ax_handle
        )

        # Telugu glyphs need a font that covers them. Fall back to default fonts if loading fails.
        if effective_font_path:
            try:
                font_props_telugu = FontProperties(fname=effective_font_path)
                ax_handle.set_xticklabels(ax_handle.get_xticklabels(), fontproperties=font_props_telugu, rotation=60, ha="right")
                ax_handle.set_yticklabels(ax_handle.get_yticklabels(), fontproperties=font_props_telugu)
            except Exception as font_error:
                print(f"Warning: Failed to apply Telugu font from '{effective_font_path}'. Error: {font_error}")
                ax_handle.set_xticklabels(ax_handle.get_xticklabels(), rotation=60, ha="right")
        else:
            ax_handle.set_xticklabels(ax_handle.get_xticklabels(), rotation=60, ha="right")

        ax_handle.set_xlabel('Source Input Sequence Tokens', fontsize=12, family='sans-serif')
        ax_handle.set_ylabel('Predicted Output Sequence Tokens', fontsize=12, family='sans-serif')
        ax_handle.set_title(f'Attention Distribution (Sample ID: {unique_plot_identifier})', fontsize=14, family='sans-serif')
        fig_handle.tight_layout()

        if wandb.run is not None:
            wandb.log({f"Attention_Heatmap_Plot_ID_{unique_plot_identifier}": wandb.Image(fig_handle)})
        else:
            # Without an active wandb run, keep a local copy instead.
            local_save_path = f"attention_heatmap_sid_{unique_plot_identifier}.png"
            try:
                fig_handle.savefig(local_save_path)
                print(f"Attention heatmap saved locally to: {local_save_path}")
            except Exception as save_err:
                print(f"Error saving attention heatmap locally: {save_err}")
    finally:
        # Release the figure even if plotting or logging raised.
        plt.close(fig_handle)

# --- Prediction Report Generation Function ---
def generate_predictions_report(eval_encoder_model, eval_decoder_model,
                                eval_data_meta_config, eval_report_data_loader,
                                eval_current_device, eval_h_params_config,
                                eval_loss_fn_criterion,
                                num_heatmaps_to_generate=3,
                                font_path_for_telugu_display=None):
    """
    Write a CSV of source/target/predicted strings and plot attention heatmaps.

    Every sample in `eval_report_data_loader` is decoded greedily (teacher
    forcing off). One row per sample is written to
    'model_predictions_and_attentions_summary.csv', with special tokens
    stripped. Heatmaps are drawn for up to `num_heatmaps_to_generate` samples
    of the first batch.

    Args:
        eval_encoder_model, eval_decoder_model: Trained model instances.
        eval_data_meta_config (dict): Data config; reads "source_index_char" and
            "target_index_char" and is forwarded to `make_strings`.
        eval_report_data_loader (DataLoader): Yields (source, target) index batches.
        eval_current_device (torch.device): Computation device.
        eval_h_params_config (dict): Accepted for interface compatibility; not read here.
        eval_loss_fn_criterion: Loss instance forwarded to the decoder.
        num_heatmaps_to_generate (int): Number of attention heatmaps to plot from the first batch.
        font_path_for_telugu_display (str, optional): Font file for Telugu tick labels.
    """
    def _fit_hidden_to_batch(hidden_state, batch_size):
        """Trim an initial RNN state to `batch_size` rows along dim 1.

        Accepts a single tensor or a tuple of tensors (LSTM's (h, c)).
        The type is checked before `.shape` is touched, so LSTM tuples are
        safe. Only 3-D tensors larger than `batch_size` along dim 1 are sliced.
        """
        if isinstance(hidden_state, tuple):
            return tuple(_fit_hidden_to_batch(part, batch_size) for part in hidden_state)
        if (isinstance(hidden_state, torch.Tensor) and hidden_state.dim() == 3
                and hidden_state.size(1) > batch_size):
            return hidden_state[:, :batch_size, :]
        return hidden_state

    eval_encoder_model.eval()
    eval_decoder_model.eval()

    csv_output_filename = 'model_predictions_and_attentions_summary.csv'
    csv_data_rows_list = [['Original_Source_String', 'True_Target_String', 'Model_Predicted_String']]

    # The first batch is kept on the CPU for the attention plots.
    first_batch_attention_data = None
    first_batch_source_sequences = None
    first_batch_predicted_sequences = None

    with torch.no_grad():
        for current_batch_number, (batch_src_data, batch_tgt_data) in enumerate(eval_report_data_loader):
            batch_src_data = batch_src_data.to(eval_current_device)
            batch_tgt_data = batch_tgt_data.to(eval_current_device)
            num_samples_in_current_batch = batch_src_data.size(0)

            # --- Encoder --- (the initial state is trimmed for a smaller final batch)
            enc_initial_state_report = _fit_hidden_to_batch(
                eval_encoder_model.getInitialState(), num_samples_in_current_batch
            )
            all_enc_outputs_report, final_enc_state_report = eval_encoder_model(
                batch_src_data, enc_initial_state_report
            )

            # --- Decoder (teacher forcing always off for the report) ---
            dec_output_indices_report, attns_matrix_report, _, _ = eval_decoder_model(
                final_enc_state_report,
                all_enc_outputs_report,
                batch_tgt_data,
                eval_loss_fn_criterion,
                teacher_forcing_is_enabled=False
            )

            if current_batch_number == 0:
                first_batch_attention_data = attns_matrix_report.cpu()
                first_batch_source_sequences = batch_src_data.cpu()
                first_batch_predicted_sequences = dec_output_indices_report.cpu()

            # Iterate over the real batch size, which may be smaller for the last batch.
            for sample_idx_in_batch in range(num_samples_in_current_batch):
                src_str_plain, tgt_str_plain, pred_str_plain = make_strings(
                    eval_data_meta_config,
                    batch_src_data[sample_idx_in_batch],
                    batch_tgt_data[sample_idx_in_batch],
                    dec_output_indices_report[sample_idx_in_batch]
                )
                csv_data_rows_list.append([
                    remove_padding(src_str_plain),
                    remove_padding(tgt_str_plain),
                    remove_padding(pred_str_plain)
                ])

    try:
        with open(csv_output_filename, mode='w', newline='', encoding='utf-8') as csvfile_handle:
            csv.writer(csvfile_handle).writerows(csv_data_rows_list)
        print(f"Prediction report successfully saved to: {csv_output_filename}")
    except IOError as io_err:
        print(f"Error: Failed to write prediction CSV report. Details: {io_err}")

    if first_batch_attention_data is not None and num_heatmaps_to_generate > 0:
        actual_num_heatmaps = min(num_heatmaps_to_generate, first_batch_attention_data.size(0))
        source_index_char = eval_data_meta_config["source_index_char"]
        target_index_char = eval_data_meta_config["target_index_char"]

        for plot_iter_idx in range(actual_num_heatmaps):
            # Unknown indices render as '?'. The y-axis uses the model's own predictions.
            src_tokens_for_plot = [
                source_index_char.get(k_idx.item(), '?') for k_idx in first_batch_source_sequences[plot_iter_idx]
            ]
            pred_tokens_for_plot = [
                target_index_char.get(k_idx.item(), '?') for k_idx in first_batch_predicted_sequences[plot_iter_idx]
            ]
            plot_attention_heatmap(
                first_batch_attention_data[plot_iter_idx],
                src_tokens_for_plot,
                pred_tokens_for_plot,
                unique_plot_identifier=f"report_sample_{plot_iter_idx}",
                custom_font_path_for_telugu=font_path_for_telugu_display
            )
    else:
        if first_batch_attention_data is None:
            print("Skipping attention heatmap generation: No data from the first batch was captured.")
        if num_heatmaps_to_generate <= 0:
            print("Skipping attention heatmap generation: Number of heatmaps requested is zero or less.")

    print("Report generation and attention plotting process has finished.")


# --- Weights & Biases Sweep Configuration ---
# This dictionary defines the search space and strategy for hyperparameter optimization.
# Using more varied distributions and value sets.
# Hyperparameter sweep definition consumed by wandb.
# Each parameter is either a discrete `values` list or a continuous
# `distribution` with `min`/`max` bounds.
wandb_sweep_configuration_dict = dict(
    method='bayes',  # Bayesian search; 'random' or 'grid' are alternatives.
    name='Seq2Seq_Transliteration_Attention_Sweep_Refined_v2',
    # The optimized key must match a metric logged through wandb.log.
    metric=dict(goal='maximize', name='val_accuracy_epoch_seq'),
    parameters=dict(
        epochs={'values': [10, 15, 20]},
        learning_rate={'distribution': 'log_uniform_values', 'min': 1e-5, 'max': 1e-3},
        batch_size={'values': [32, 64, 128]},
        char_embd_dim={'values': [64, 128, 192, 256]},
        number_of_layers={'values': [1, 2, 3]},
        optimizer={'values': ['adamw', 'adam']},
        cell_type={'values': ["GRU", "LSTM"]},
        hidden_layer_neurons={'values': [128, 256, 384, 512]},
        dropout={'distribution': 'uniform', 'min': 0.1, 'max': 0.5},
    ),
    # Hyperband early stopping: min_iter is the minimum number of metric
    # reports before a run can be stopped; s is the number of brackets.
    early_terminate=dict(type='hyperband', min_iter=5, s=2),
)
sweep_params = wandb_sweep_configuration_dict  # Legacy alias kept for existing callers.

# --- Main Execution Function for a Single Sweep Run ---
def main():
    """Run one sweep trial: init W&B, train, evaluate on the test set, and report.

    Intended to be called once per trial by ``wandb.agent``, which fills
    ``wandb.config`` with the hyperparameters chosen for the trial.

    If ``wandb.init`` itself raises (e.g. W&B is not configured), the trial
    falls back to the module-level ``default_hyperparameter_values`` and no
    metrics are logged to W&B. A failure while merely naming the run is
    reported but does NOT discard the sweep-provided hyperparameters.

    Raises:
        Any exception from data preparation, training, evaluation or report
        generation. Before re-raising, an active W&B run is closed with a
        non-zero exit code so the sweep records the trial as failed.
    """
    wandb_run = None
    try:
        # The project name groups all trials of this sweep in the W&B UI.
        wandb_run = wandb.init(
            project="DL_Assignment_3_Attention_Refactored_Sweep_Project"
        )
    except Exception as wandb_init_err:
        # Only a failed init switches to the local defaults.
        print(f"Error initializing wandb run: {wandb_init_err}. Using default h_params for a local run.")
        run_config = default_hyperparameter_values
    else:
        # `wandb.config` holds the hyperparameters W&B picked for this trial.
        run_config = wandb.config
        # The descriptive name is cosmetic; if building or setting it fails the
        # trial must still run with the sweep's hyperparameters, not defaults.
        # Assigning `run.name` is persisted by W&B without an explicit save().
        try:
            wandb.run.name = _build_run_name(run_config)
        except Exception as naming_err:
            print(f"Could not set a custom wandb run name ({naming_err}); keeping the default name.")

    try:
        _run_training_pipeline(run_config, wandb_run)
    except Exception:
        # Close the run as failed so the sweep does not treat it as a success,
        # then let the error propagate to the caller / agent.
        if wandb_run and wandb.run:
            wandb.finish(exit_code=1)
        raise

    # Finish the wandb run if it was initialized.
    if wandb_run and wandb.run:
        wandb.finish()
        print("wandb run finished.")
    else:
        print("Local run finished (wandb was not active or failed to init).")


def _build_run_name(run_config):
    """Return a W&B run name that encodes this trial's hyperparameters.

    Missing keys fall back to placeholder values so the name can always be built
    from a config that lacks some of them.
    """
    return (
        "run_{cell_type}_{optimizer}_ep{epochs}_lr{lr:.1e}_emb{emb_dim}"
        "_hid{hid_neu}_bs{bs}_lay{layers}_do{drop:.2f}"
    ).format(
        cell_type=run_config.get('cell_type', 'DEF_CELL'),
        optimizer=run_config.get('optimizer', 'DEF_OPT'),
        epochs=run_config.get('epochs', 0),
        lr=run_config.get('learning_rate', 0.0),
        emb_dim=run_config.get('char_embd_dim', 0),
        hid_neu=run_config.get('hidden_layer_neurons', 0),
        bs=run_config.get('batch_size', 0),
        layers=run_config.get('number_of_layers', 0),
        drop=run_config.get('dropout', 0.0),
    )


def _run_training_pipeline(run_config, wandb_run):
    """Build dataloaders, train, evaluate on the test split and write the final report.

    Args:
        run_config: Hyperparameters for this trial (``wandb.config`` or a plain dict).
        wandb_run: The object returned by ``wandb.init``, or None when W&B is
            inactive; test metrics are logged only when it is set.
    """
    # 1. DataLoaders. Uses the module-level train/val/test source and target data.
    print(f"Preparing DataLoaders with batch size: {run_config.get('batch_size', -1)}...")
    train_dl, val_dl, test_dl, data_cfg_obj = prepare_dataloaders(
        train_source, train_target,
        val_source, val_target,
        test_source, test_target,
        run_config,
    )

    # 2. Training, with teacher forcing enabled.
    print(f"Starting training process with config: {run_config}")
    trained_enc_model, trained_dec_model, loss_fn_used = train(
        run_config, data_cfg_obj, device,
        train_dl, val_dl,
        enable_tf_in_training_main=True,
    )

    # 3. Evaluate the trained models on the held-out test set.
    print("--- Training Complete. Evaluating on Test Set... ---")
    test_loss_avg_token, test_seq_accuracy = evaluate(
        trained_enc_model, trained_dec_model,
        data_cfg_obj, test_dl,
        device, run_config, loss_fn_used,
        is_test_evaluation_run=True,
    )
    print(f"  Test Set Results: Avg Loss/Token ~ {test_loss_avg_token:.4f}, Seq. Accuracy: {test_seq_accuracy:.4f}")

    if wandb_run and wandb.run:
        wandb.log({
            "test_loss_final_avg_token": test_loss_avg_token,
            "test_accuracy_final_seq": test_seq_accuracy,
        })

    # 4. Prediction report and attention heatmaps on the test split.
    # TELUGU_FONT_PATH_GLOBAL must point to a font available in this environment.
    print("--- Generating Final Prediction Report and Attention Heatmaps... ---")
    generate_predictions_report(
        trained_enc_model, trained_dec_model,
        data_cfg_obj, test_dl,
        device, run_config, loss_fn_used,
        num_heatmaps_to_generate=5,
        font_path_for_telugu_display=TELUGU_FONT_PATH_GLOBAL,
    )

# --- Script Entry Point for Sweep Agent ---
if __name__ == '__main__':
    # This script is meant to be driven by a W&B sweep agent, which calls
    # `main()` once per trial. It deliberately does NOT start an agent itself.
    #
    # Typical workflow:
    #   1. Create a sweep from `wandb_sweep_configuration_dict`, e.g.
    #          sweep_id = wandb.sweep(
    #              sweep=wandb_sweep_configuration_dict,
    #              project="DL_Assignment_3_Attention_Refactored_Sweep_Project",
    #          )
    #      or from the CLI with an equivalent YAML file: `wandb sweep sweep_config.yaml`.
    #   2. Start an agent, either from the CLI (`wandb agent YOUR_SWEEP_ID`) or
    #      programmatically: `wandb.agent(sweep_id, function=main, count=N)`.
    #
    # NOTE: an earlier version of this script attached to the hard-coded sweep
    # "f4esgkqv", which lives in project "DL Assignment 3 With Attention".
    # `main()` logs to "DL_Assignment_3_Attention_Refactored_Sweep_Project";
    # reusing that old sweep ID requires the project names to be aligned.
    #
    # For a quick local run without a sweep, `main()` catches a failing
    # `wandb.init` and falls back to `default_hyperparameter_values`.
    print("Script execution started. This script is designed to be run with `wandb agent`.")
    print("If you have created a sweep, run 'wandb agent YOUR_SWEEP_ID' in your terminal.")
    print("Example sweep config is 'wandb_sweep_configuration_dict'.")
    print("The main function for the agent is 'main'.")