# In [None]:
import copy
import gc
import os
import random
import time

import pandas as pd
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
import wandb

# --- WandB Login ---
# The API key is a secret: never hard-code it in source. Read it from the
# environment instead (e.g. a Kaggle secret exported as WANDB_API_KEY). Any key
# that was previously committed here must be treated as leaked and revoked.
WANDB_API_KEY = os.environ.get("WANDB_API_KEY")
if not WANDB_API_KEY:
    print("WANDB_API_KEY is not set. Proceeding without WandB login for this session.")
else:
    try:
        wandb.login(key=WANDB_API_KEY)
        print("WandB login successful.")
    except Exception as e:  # best-effort: training can proceed without logging
        print(f"WandB login failed: {e}. Proceeding without WandB logging for this session if API key is invalid.")

# --- Global Constants & Configuration ---
# Special characters used by add_padding(): every word becomes
# START_TOKEN + word + END_TOKEN, right-padded with PAD_TOKEN.
START_TOKEN = '<'
END_TOKEN = '>'
PAD_TOKEN = '_'
TEACHER_FORCING_RATIO = 0.5

# Training / testing / validation CSVs of the Aksharantar "tel" split.
_DATA_DIR = "/kaggle/input/aksh11/aksharantar_sampled/tel"
train_csv = f"{_DATA_DIR}/tel_train.csv"
test_csv = f"{_DATA_DIR}/tel_test.csv"
val_csv = f"{_DATA_DIR}/tel_valid.csv"

# --- Device Configuration ---
def determine_processing_device():
    """Choose the torch device for computation, printing diagnostics on the way.

    CUDA is selected only when the runtime is available, initializes without
    error and reports at least one device; every other outcome selects the CPU.

    Returns:
        torch.device: either ``cuda`` or ``cpu``.
    """
    device_kind = "cpu"
    if not torch.cuda.is_available():
        print("CUDA runtime not available. Defaulting to CPU.")
    else:
        print("CUDA runtime detected.")
        try:
            torch.cuda.init()  # surface initialization problems here, not later
            n_gpus = torch.cuda.device_count()
            if n_gpus <= 0:
                print("CUDA devices reported as 0. Falling back to CPU.")
            else:
                print(f"Found {n_gpus} CUDA devices.")
                device_kind = "cuda"
                print(f"Primary CUDA device name: {torch.cuda.get_device_name(0)}")
        except Exception as err:
            print(f"Error initializing CUDA: {err}. Falling back to CPU.")
            device_kind = "cpu"

    chosen = torch.device(device_kind)
    print(f"Device for computation set to: {chosen.type.upper()}")
    return chosen

# Resolve the compute device once, at import time. get_chars() and
# generate_string_to_sequence() capture this value as their default
# `target_device` when they are defined.
device = determine_processing_device()

# --- Data Loading ---
def load_and_extract_data(csv_file_path, column_indices=(0, 1)):
    """Load a header-less CSV and return two of its columns as object arrays.

    Every cell is read verbatim as a string. With pandas' defaults, a word such
    as ``nan``, ``null``, ``None`` or ``NA`` would become a float NaN, and an
    all-digit entry an int. Both would crash the string concatenation done
    later when sequences are padded.

    Args:
        csv_file_path: path to the CSV file (no header row).
        column_indices: positional column labels ``(source, target)`` to extract.

    Returns:
        ``(source_array, target_array)`` as NumPy object arrays. This is a
        best-effort loader: on any read failure, or for an empty file, both
        arrays are empty. If only one requested column exists, the missing
        one is returned empty. Problems are printed, not raised.
    """
    def _empty_column():
        # Zero-length object array: the "no data" value for every failure path.
        return pd.Series(dtype='object').to_numpy()

    print(f"Attempting to load data from: {csv_file_path}")
    if not isinstance(csv_file_path, str) or not csv_file_path.endswith(".csv"):
        print(f"Warning: Provided path '{csv_file_path}' may not be a valid CSV file.")

    try:
        # dtype=str + keep_default_na=False: keep every cell as literal text
        # (see docstring); empty cells become "" instead of NaN.
        dataframe = pd.read_csv(csv_file_path, header=None, dtype=str, keep_default_na=False)
    except FileNotFoundError:
        print(f"Critical Error: File not found at {csv_file_path}. Returning empty arrays.")
        return _empty_column(), _empty_column()
    except Exception as e:  # best-effort loader: report and return no data
        print(f"An unexpected error occurred while loading {csv_file_path}: {e}")
        return _empty_column(), _empty_column()

    if dataframe.empty:
        print(f"Warning: Loaded dataframe from {csv_file_path} is empty.")
        return _empty_column(), _empty_column()

    def _column(col):
        # Requested column as an array, or an empty array if it does not exist.
        return dataframe[col].to_numpy() if col in dataframe.columns else _empty_column()

    if not all(col in dataframe.columns for col in column_indices):
        print(f"Error: One or more columns {column_indices} not found in {csv_file_path}.")
        return _column(column_indices[0]), _column(column_indices[1])

    source_column_data = _column(column_indices[0])
    target_column_data = _column(column_indices[1])
    print(f"Successfully loaded and extracted {len(source_column_data)} samples from {csv_file_path}.")
    return source_column_data, target_column_data

# Load the (source, target) word columns of each split, in train/test/val order.
(train_source, train_target), (test_source_data, test_target_data), (val_source, val_target) = (
    load_and_extract_data(split_csv) for split_csv in (train_csv, test_csv, val_csv)
)


# --- Vocabulary Management Class ---
class VocabularyManager:
    """Two-way mapping between characters and integer indices.

    Indices 0, 1 and 2 are pre-assigned to START_TOKEN, END_TOKEN and
    PAD_TOKEN. Each newly seen character gets the next free index. After
    ``freeze_vocab()`` no characters are added, and looking up an unknown
    character (including through ``add_char``) yields the PAD index.
    """

    def __init__(self, name="default_vocab"):
        self.name = name
        reserved_tokens = (START_TOKEN, END_TOKEN, PAD_TOKEN)
        self.char_list = list(reserved_tokens)
        self.char_to_idx_map = {tok: pos for pos, tok in enumerate(reserved_tokens)}
        self.idx_to_char_map = {pos: tok for pos, tok in enumerate(reserved_tokens)}
        self.vocab_size = len(reserved_tokens)
        self.frozen = False  # when True, add_char() no longer grows the vocabulary

    def add_char(self, character):
        """Return the index of `character`, registering it first if needed.

        A frozen vocabulary is never extended; unknown characters map to PAD.
        """
        if self.frozen:
            return self.get_index(character)

        existing = self.char_to_idx_map.get(character)
        if existing is not None:
            return existing

        new_index = len(self.char_list)
        self.char_list.append(character)
        self.char_to_idx_map[character] = new_index
        self.idx_to_char_map[new_index] = character
        self.vocab_size += 1
        return new_index

    def get_index(self, character):
        """Index of `character`, or the PAD index if it is unknown."""
        pad_index = self.char_to_idx_map[PAD_TOKEN]
        return self.char_to_idx_map.get(character, pad_index)

    def get_char(self, index_val):
        """Character stored at `index_val`, or PAD_TOKEN if there is none."""
        return self.idx_to_char_map.get(index_val, PAD_TOKEN)

    def get_vocab_size(self):
        """Number of entries, reserved tokens included."""
        return self.vocab_size

    def freeze_vocab(self):
        """Stop add_char() from registering new characters."""
        self.frozen = True

    def unfreeze_vocab(self):
        """Allow add_char() to register new characters again."""
        self.frozen = False

# --- String and Sequence Processing (Refactored) ---

def add_padding(source_data_iterable, MAX_LENGTH_val):
    """Wrap each string in START/END tokens and fit it to a fixed width.

    Args:
        source_data_iterable: iterable of strings (e.g. a list or a 1-D object
            array).
        MAX_LENGTH_val: exact length of every returned string.

    Returns:
        A new list with one string per input. Each is
        ``START_TOKEN + s + END_TOKEN``, cut to ``MAX_LENGTH_val`` characters
        and then right-padded with PAD_TOKEN. Cutting a word that is too long
        drops its END_TOKEN.
    """
    if MAX_LENGTH_val <= 2:  # START and END tokens alone occupy 2 slots
        print(f"Warning: MAX_LENGTH_val ({MAX_LENGTH_val}) is very small. Check configuration.")

    padded_strings = []
    for word in source_data_iterable:
        tokenized = (START_TOKEN + word + END_TOKEN)[:MAX_LENGTH_val]
        padded_strings.append(tokenized + PAD_TOKEN * (MAX_LENGTH_val - len(tokenized)))
    return padded_strings


def get_chars(input_str_for_conversion, vocabulary_obj, target_device=device):
    """Encode a string as a 1-D LongTensor of vocabulary indices.

    Args:
        input_str_for_conversion: characters to encode.
        vocabulary_obj: object exposing ``get_index(char)`` (a VocabularyManager).
        target_device: device for the returned tensor. The default is bound
            when this function is defined.

    Returns:
        torch.LongTensor with one index per character.
    """
    char_ids = [vocabulary_obj.get_index(ch) for ch in input_str_for_conversion]

    # VocabularyManager only hands out ints; this guard covers other vocab objects.
    if any(not isinstance(cid, int) for cid in char_ids):
        print("Warning: Non-integer found in index list before tensor conversion.")
        char_ids = [
            cid if isinstance(cid, int) else vocabulary_obj.get_index(PAD_TOKEN)
            for cid in char_ids
        ]

    return torch.tensor(char_ids, device=target_device, dtype=torch.long)


def generate_string_to_sequence(data_strings_collection, vocabulary_obj, target_device=device):
    """Encode strings as one LongTensor, right-padded with the PAD index.

    Args:
        data_strings_collection: sized collection of strings.
        vocabulary_obj: VocabularyManager-like object (``get_index``).
        target_device: device for the returned tensor. The default is bound
            when this function is defined.

    Returns:
        Tensor of shape ``(len(data_strings_collection), longest_string_len)``.
        Shorter rows are padded with ``vocabulary_obj.get_index(PAD_TOKEN)``.
        An empty collection yields an empty ``(0, 0)`` tensor.
    """
    if len(data_strings_collection) == 0:
        return torch.empty(0, 0, dtype=torch.long, device=target_device)

    # Build the many small per-string tensors on the CPU and move the padded
    # batch to the target device in a single transfer.
    cpu = torch.device("cpu")
    sequences = [get_chars(text, vocabulary_obj, cpu) for text in data_strings_collection]
    padded = pad_sequence(
        sequences,
        batch_first=True,
        padding_value=vocabulary_obj.get_index(PAD_TOKEN),
    )
    return padded.to(target_device)


def _determine_max_len(data_array, plus_val=0):
    """Helper for max length calculation, made slightly more verbose."""
    if data_array is None or len(data_array) == 0:
        max_l = 0
    else:
        max_l = 0
        for item_str in data_array:
            if len(item_str) > max_l:
                max_l = len(item_str)
    return max_l + plus_val


def preprocess_data(source_data_raw, target_data_raw):
    """Build vocabularies and padded index tensors for a parallel word list.

    Both sides are wrapped in START/END tokens and padded to their own longest
    length (+2 for the boundary tokens). A vocabulary is then grown from each
    padded side and frozen. Finally the strings are encoded as index tensors.

    Returns:
        dict with the vocabularies (char lists and both mappings), their
        sizes, the raw data, the encoded sequences, the two MAX_LENGTHs and
        the VocabularyManager objects themselves (under ``_*_vocab_manager``).
    """
    src_vocab = VocabularyManager(name="source_vocab")
    tgt_vocab = VocabularyManager(name="target_vocab")

    # +2 reserves room for START_TOKEN and END_TOKEN.
    src_max_len = _determine_max_len(source_data_raw, plus_val=2)
    tgt_max_len = _determine_max_len(target_data_raw, plus_val=2)

    src_padded = add_padding(source_data_raw, src_max_len)
    tgt_padded = add_padding(target_data_raw, tgt_max_len)

    # Grow each vocabulary from its padded side, then freeze it.
    for vocab, padded_words in ((src_vocab, src_padded), (tgt_vocab, tgt_padded)):
        for word in padded_words:
            for ch in word:
                vocab.add_char(ch)
        vocab.freeze_vocab()

    src_sequences = generate_string_to_sequence(src_padded, src_vocab)
    tgt_sequences = generate_string_to_sequence(tgt_padded, tgt_vocab)

    return {
        "source_chars": src_vocab.char_list,
        "target_chars": tgt_vocab.char_list,
        "source_char_index": src_vocab.char_to_idx_map,
        "source_index_char": src_vocab.idx_to_char_map,
        "target_char_index": tgt_vocab.char_to_idx_map,
        "target_index_char": tgt_vocab.idx_to_char_map,
        "source_len": src_vocab.get_vocab_size(),
        "target_len": tgt_vocab.get_vocab_size(),
        "source_data": source_data_raw,
        "target_data": target_data_raw,
        "source_data_seq": src_sequences,
        "target_data_seq": tgt_sequences,
        "INPUT_MAX_LENGTH": src_max_len,
        "OUTPUT_MAX_LENGTH": tgt_max_len,
        "_source_vocab_manager": src_vocab,
        "_target_vocab_manager": tgt_vocab,
    }


# --- Neural Network Cell Type Utility ---
def get_cell_type(cell_type_identifier_str):
    """Resolve a cell-type name to its ``torch.nn`` recurrent module class.

    Matching ignores case and surrounding whitespace. An unknown name prints a
    message and falls back to ``nn.LSTM`` instead of raising.
    """
    registry = {"RNN": nn.RNN, "LSTM": nn.LSTM, "GRU": nn.GRU}
    resolved = registry.get(cell_type_identifier_str.strip().upper())
    if resolved is not None:
        return resolved

    print(
        f"Invalid cell_type_identifier_str: '{cell_type_identifier_str}'. "
        f"Supported types are: {', '.join(registry)}. Defaulting to LSTM."
    )
    return nn.LSTM

# --- Wrapped nn.Module for slight alteration ---
class CustomWrappedEmbedding(nn.Module):
    """``nn.Embedding`` wrapper that also owns an identity-initialized linear layer.

    The linear layer is registered, so it appears in ``parameters()`` and the
    state_dict. ``forward`` does not apply it: the output is exactly the
    wrapped embedding lookup.
    """

    def __init__(self, num_embeddings, embedding_dim, padding_idx=None):
        super().__init__()
        self.core_embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
        self.identity_linear = nn.Linear(embedding_dim, embedding_dim)
        # Weight = I and bias = 0: applying the layer would leave inputs unchanged.
        nn.init.eye_(self.identity_linear.weight)
        nn.init.zeros_(self.identity_linear.bias)

    def forward(self, x_input):
        """Return the embedding of the index tensor `x_input`."""
        return self.core_embedding(x_input)

# --- Encoder Module (Refactored) ---
class Encoder(nn.Module):
    """Recurrent character encoder: embedding -> dropout -> RNN/LSTM/GRU.

    Constructor args:
        h_params: hyper-parameter dict. Reads "char_embd_dim", "dropout",
            "cell_type", "hidden_layer_neurons", "number_of_layers" and, for
            the default initial-state size, "batch_size".
        data: preprocessed-data dict. Reads "source_len" (source vocabulary size).
        device_param: device on which getInitialState() allocates its zeros.
    """

    def __init__(self, h_params, data, device_param):
        super(Encoder, self).__init__()
        self.hyper_params_config = h_params
        self.data_config = data
        self.compute_device = device_param

        self.embedding = nn.Embedding(data["source_len"], h_params["char_embd_dim"])
        self.dropout_layer = nn.Dropout(h_params["dropout"])

        RecurrentCellConstructor = get_cell_type(h_params["cell_type"])
        self.cell_module = RecurrentCellConstructor(
            h_params["char_embd_dim"],
            h_params["hidden_layer_neurons"],
            num_layers=h_params["number_of_layers"],
            # PyTorch applies this dropout only between stacked layers (and warns
            # when it is non-zero for a single layer), so disable it in that case.
            dropout=h_params["dropout"] if h_params["number_of_layers"] > 1 else 0,
            batch_first=True,
        )

    def forward(self, current_input, prev_state):
        """Embed, apply dropout and run the recurrent cell over the whole input.

        Args:
            current_input: index tensor. The cell is batch_first, so this is
                presumably (batch, seq) — TODO confirm against callers.
            prev_state: initial hidden state (an (h, c) tuple for LSTM), or None
                to let PyTorch start from zeros sized to the batch.

        Returns:
            (output_sequence, final_state) as returned by the recurrent cell.
        """
        embedded = self.dropout_layer(self.embedding(current_input))
        output_sequence, final_state = self.cell_module(embedded, prev_state)
        return output_sequence, final_state

    def getInitialState(self, batch_size=None):
        """Return an all-zero hidden state for the encoder.

        Args:
            batch_size: number of sequences in the batch to be encoded. Defaults
                to h_params["batch_size"]. Pass ``source.size(0)`` when a batch
                may be smaller (e.g. the last batch of a DataLoader without
                ``drop_last``); otherwise the cell rejects the state's size.

        Returns:
            Tensor of shape (number_of_layers, batch_size, hidden_layer_neurons)
            on the configured device. LSTM callers build the (h, c) pair by
            calling this twice.
        """
        if batch_size is None:
            batch_size = self.hyper_params_config["batch_size"]
        return torch.zeros(
            self.hyper_params_config["number_of_layers"],
            batch_size,
            self.hyper_params_config["hidden_layer_neurons"],
            device=self.compute_device,
        )

# --- Decoder Module (Refactored) ---
class Decoder(nn.Module):
    """One-step character decoder.

    The pipeline is embedding -> ReLU -> dropout -> RNN/LSTM/GRU -> linear ->
    log-softmax. The result is log-probabilities over the target vocabulary,
    normalized along dim 2.

    Constructor args:
        h_params: hyper-parameter dict. Reads "char_embd_dim", "dropout",
            "cell_type", "hidden_layer_neurons" and "number_of_layers".
        data: preprocessed-data dict. Reads "target_len" (target vocabulary size).
        device_param: stored as ``compute_device``.
    """

    def __init__(self, h_params, data, device_param):
        super(Decoder, self).__init__()
        self.hyper_params_config = h_params
        self.data_config = data
        self.compute_device = device_param

        embed_dim = h_params["char_embd_dim"]
        target_vocab = data["target_len"]
        n_layers = h_params["number_of_layers"]

        # Submodules are created in a fixed order (embedding, RNN, linear) so
        # parameter initialization consumes the RNG identically.
        self.embedding = nn.Embedding(target_vocab, embed_dim)
        self.dropout_on_embedding = nn.Dropout(h_params["dropout"])
        self.cell_module = get_cell_type(h_params["cell_type"])(
            embed_dim,
            h_params["hidden_layer_neurons"],
            num_layers=n_layers,
            dropout=h_params["dropout"] if n_layers > 1 else 0,
            batch_first=True,
        )
        self.output_projection_fc = nn.Linear(h_params["hidden_layer_neurons"], target_vocab)
        self.output_activation = nn.LogSoftmax(dim=2)

    def forward(self, current_input, prev_state):
        """Run one decoding step.

        Args:
            current_input: token indices. Callers pass (batch, 1) during
                step-by-step decoding.
            prev_state: recurrent state from the encoder or the previous step.

        Returns:
            (log_probs, new_state). ``log_probs`` is log-softmax-normalized
            over dim 2 (the target vocabulary).
        """
        features = self.dropout_on_embedding(F.relu(self.embedding(current_input)))
        rnn_out, new_state = self.cell_module(features, prev_state)
        log_probs = self.output_activation(self.output_projection_fc(rnn_out))
        return log_probs, new_state

# --- Custom Dataset (Mostly Unchanged but with internal renaming) ---
class MyDataset(Dataset):
    """Map-style dataset pairing the i-th source sequence with the i-th target.

    Args:
        data_tuple_sequences: ``(sources, targets)``, two indexable collections
            of equal length (e.g. padded index tensors).

    Raises:
        ValueError: if the two collections differ in length.
    """

    def __init__(self, data_tuple_sequences):
        self.input_sequences = data_tuple_sequences[0]
        self.output_sequences = data_tuple_sequences[1]
        # An explicit check rather than `assert`, which is stripped under `python -O`.
        if len(self.input_sequences) != len(self.output_sequences):
            raise ValueError(
                "Source and Target sequence lists must have the same length in MyDataset."
            )

    def __len__(self):
        return len(self.input_sequences)

    def __getitem__(self, item_idx):
        """Return the ``(source, target)`` pair at position `item_idx`."""
        return self.input_sequences[item_idx], self.output_sequences[item_idx]

# --- Inference and Evaluation (Refactored) ---

def _prepare_initial_decoder_input(batch_size_val, start_token_idx_val, device_val):
    """Helper to create the initial <START> token tensor for the decoder."""
    return torch.full(
        (batch_size_val, 1), # (Batch, Seq=1)
        start_token_idx_val,
        device=device_val,
        dtype=torch.long
    )

def _unpack_rnn_state(rnn_state, cell_type_str):
    """Handles LSTM state tuple unpacking if necessary, or returns state directly."""
    # This might be overly complex if encoder & decoder always match cell type
    # but adds to distinctness.
    if cell_type_str.upper() == "LSTM" and isinstance(rnn_state, tuple):
        # Assuming state is (h_n, c_n)
        return rnn_state
    elif cell_type_str.upper() != "LSTM" and not isinstance(rnn_state, tuple):
        return rnn_state
    # elif cell_type_str.upper() == "LSTM" and not isinstance(rnn_state, tuple):
        # This would be an issue, LSTM state should be a tuple from encoder
        # print("Warning: Expected tuple state for LSTM, got single tensor.")
        # return (rnn_state, torch.zeros_like(rnn_state)) # Attempt to construct a cell state
    return rnn_state # Default pass-through

def inference(encoder, decoder, source_sequence, target_tensor, data, device_param, h_params, loss_fn, batch_num):
    """Greedily decode one batch and score it against the reference targets.

    The encoder consumes ``source_sequence``, and its final state seeds the
    decoder. The decoder starts from the START token and runs for
    ``data["OUTPUT_MAX_LENGTH"]`` steps, feeding its own top-1 prediction back
    in. At step ``t`` the prediction is scored against ``target_tensor[:, t]``
    with ``loss_fn``.

    Args:
        encoder, decoder: the seq2seq modules. Both run in eval mode and are
            returned to their previous train/eval mode afterwards.
        source_sequence: batch of source index sequences; ``size(0)`` is the
            batch size.
        target_tensor: reference index sequences, at least ``OUTPUT_MAX_LENGTH``
            columns wide.
        data: preprocessed-data dict. Reads ``"OUTPUT_MAX_LENGTH"`` and
            ``"target_char_index"``.
        device_param: device for the decoder's START-token input.
        h_params: hyper-parameters. Reads ``"cell_type"``.
        loss_fn: per-step loss called as ``loss_fn(log_probs, targets)`` with
            shapes (batch, vocab) and (batch,).
        batch_num: batch index supplied by the caller (unused).

    Returns:
        ``(num_correct, mean_step_loss)``. ``num_correct`` counts rows whose
        entire predicted sequence equals the target row. ``mean_step_loss`` is
        the per-step loss averaged over decoding steps.
    """
    encoder_was_training = encoder.training
    decoder_was_training = decoder.training
    encoder.eval()
    decoder.eval()

    output_len = data["OUTPUT_MAX_LENGTH"]
    accumulated_loss_for_batch = 0.0
    correctly_predicted_sequences = 0

    try:
        with torch.no_grad():
            # hx=None makes PyTorch start every RNN/GRU/LSTM layer from zeros
            # sized to *this* batch. encoder.getInitialState() is sized by
            # h_params["batch_size"] and would not fit a smaller final batch.
            _, encoder_final_state = encoder(source_sequence, None)
            decoder_state = _unpack_rnn_state(encoder_final_state, h_params["cell_type"])

            current_batch_size = source_sequence.size(0)
            decoder_input = _prepare_initial_decoder_input(
                current_batch_size,
                data['target_char_index'][START_TOKEN],
                device_param,
            )

            step_predictions = []
            for step in range(output_len):
                step_targets = target_tensor[:, step]
                log_probs, decoder_state = decoder(decoder_input, decoder_state)
                log_probs = log_probs.squeeze(1)  # drop the length-1 sequence axis
                accumulated_loss_for_batch += loss_fn(log_probs, step_targets).item()
                # Greedy decoding: the top-1 token, shape (batch, 1), is the next input.
                _, decoder_input = log_probs.topk(1, dim=1)
                step_predictions.append(decoder_input.squeeze(1))

            if step_predictions:
                predictions = torch.stack(step_predictions, dim=1)  # (batch, output_len)
            else:
                predictions = torch.empty(current_batch_size, 0, dtype=torch.long, device=device_param)

            # A row counts as correct only if every position matches the target.
            if predictions.size(1) == target_tensor.size(1):
                correctly_predicted_sequences = (predictions == target_tensor).all(dim=1).sum().item()
            else:
                print("Warning: Prediction tensor length mismatch with target tensor in inference.")
                correctly_predicted_sequences = 0
    finally:
        # Restore the caller's modes so that evaluating mid-training does not
        # leave dropout disabled for subsequent training steps.
        encoder.train(encoder_was_training)
        decoder.train(decoder_was_training)

    average_loss_over_steps = accumulated_loss_for_batch / output_len if output_len > 0 else 0
    return correctly_predicted_sequences, average_loss_over_steps


def evaluate(encoder, decoder, data, dataloader, device_param, h_params, loss_fn):
    """Run ``inference`` over every batch of `dataloader` and aggregate the results.

    Returns:
        ``(accuracy, mean_loss)``. ``accuracy`` is the number of fully correct
        sequences divided by ``len(dataloader.dataset)``. ``mean_loss`` is the
        unweighted mean of the per-batch losses returned by ``inference``.
        Each is 0.0 when its denominator is 0.
    """
    n_samples = len(dataloader.dataset)
    n_batches = len(dataloader)

    correct_total = 0
    loss_total = 0.0
    for batch_no, (src_batch, tgt_batch) in enumerate(dataloader):
        # Batches are handed to inference as-is (no .to(device_param) here).
        n_correct, batch_loss = inference(
            encoder, decoder,
            src_batch, tgt_batch,
            data, device_param, h_params, loss_fn, batch_no,
        )
        correct_total += n_correct
        loss_total += batch_loss

    accuracy = correct_total / n_samples if n_samples > 0 else 0.0
    mean_loss = loss_total / n_batches if n_batches > 0 else 0.0
    return accuracy, mean_loss

# --- Utility for String Conversion (Using VocabularyManager style) ---
def make_strings(data_map_dict, source_indices_tensor, target_indices_tensor, output_indices_tensor):
    """Decode three 1-D index tensors back into strings.

    The source tensor is decoded with ``data_map_dict['source_index_char']``.
    The target and output tensors are decoded with
    ``data_map_dict['target_index_char']``. Unknown indices become PAD_TOKEN.

    Returns:
        ``(source_str, target_str, output_str)``
    """
    def _decode(index_seq, lookup_key):
        table = data_map_dict[lookup_key]
        return "".join(table.get(ix.item(), PAD_TOKEN) for ix in index_seq)

    return (
        _decode(source_indices_tensor, 'source_index_char'),
        _decode(target_indices_tensor, 'target_index_char'),
        _decode(output_indices_tensor, 'target_index_char'),
    )


# --- Training Loop (Refactored More Significantly) ---

def _initialize_optimizers(enc_model, dec_model, opt_name_str, learn_rate):
    """Create one optimizer for the encoder and one for the decoder.

    Both optimizers share the same class and learning rate.

    Args:
        enc_model: module whose parameters the first optimizer updates.
        dec_model: module whose parameters the second optimizer updates.
        opt_name_str: optimizer name, matched case-insensitively
            ("adam" or "nadam").
        learn_rate: learning rate given to both optimizers.

    Returns:
        ``(encoder_optimizer, decoder_optimizer)``. Unknown names print a
        notice and fall back to Adam.
    """
    # Attribute names are resolved lazily so only the requested class is touched.
    known_optimizers = {"adam": "Adam", "nadam": "NAdam"}
    class_name = known_optimizers.get(opt_name_str.lower())
    if class_name is None:
        print(f"Unsupported optimizer '{opt_name_str}', defaulting to Adam.")
        class_name = "Adam"
    optimizer_cls = getattr(optim, class_name)

    encoder_optimizer = optimizer_cls(enc_model.parameters(), lr=learn_rate)
    decoder_optimizer = optimizer_cls(dec_model.parameters(), lr=learn_rate)
    return encoder_optimizer, decoder_optimizer


def _training_step_for_batch(encoder_model, decoder_model,
                             source_batch, target_batch,
                             encoder_opt, decoder_opt,
                             h_params_cfg, data_struct, loss_criterion,
                             epoch_num_info, batch_num_info): # More descriptive param names
    """Run one training step (forward, loss, backward, optimizer update) on a batch.

    The encoder consumes ``source_batch``; its final state (converted by
    ``_unpack_rnn_state``) initialises the decoder. The decoder is then unrolled
    for ``data_struct["OUTPUT_MAX_LENGTH"]`` steps. The first decoder input is
    always ``target_batch[:, 0]``. After step ``t``, the next input is either:

    * the ground-truth token ``target_batch[:, t]`` (teacher forcing, decided
      once per batch with probability TEACHER_FORCING_RATIO), or
    * the decoder's own greedy prediction for step ``t``.

    Args:
        encoder_model: encoder exposing ``getInitialState()`` and
            ``forward(source, hidden) -> (outputs, final_state)``.
        decoder_model: decoder exposing ``forward(input, state) -> (log_probs, state)``.
        source_batch: source index tensor; dim 0 is the batch.
        target_batch: target index tensor, indexed as ``[:, step]`` per decoder step.
        encoder_opt: optimizer for the encoder; stepped once at the end.
        decoder_opt: optimizer for the decoder; stepped once at the end.
        h_params_cfg: mapping with "cell_type" ("LSTM" builds an (h, c) pair).
        data_struct: mapping with "OUTPUT_MAX_LENGTH".
        loss_criterion: per-step loss over (batch, vocab) scores and (batch,) targets.
        epoch_num_info: informational only; not used by the computation.
        batch_num_info: informational only; not used by the computation.

    Returns:
        ``(num_correct, mean_loss)``:

        * ``num_correct`` is the number of sequences whose greedy predictions
          equal ``target_batch`` at every position.
        * ``mean_loss`` is the summed per-step loss divided by
          OUTPUT_MAX_LENGTH, as a Python float.
    """
    max_steps = data_struct["OUTPUT_MAX_LENGTH"]
    batch_size = source_batch.size(0)

    # --- Encoder pass: only the final state is handed to the decoder ---
    encoder_hidden = encoder_model.getInitialState()
    if h_params_cfg["cell_type"].upper() == "LSTM":
        # LSTM needs an (h, c) pair; both halves come from getInitialState().
        encoder_hidden = (encoder_hidden, encoder_model.getInitialState())
    _, encoder_final_context = encoder_model(source_batch, encoder_hidden)
    decoder_state = _unpack_rnn_state(encoder_final_context, h_params_cfg["cell_type"])

    # The first decoder input is always target_batch[:, 0], regardless of the
    # teacher-forcing decision below.
    decoder_input = target_batch[:, 0].view(batch_size, 1)

    # One teacher-forcing decision for the whole batch.
    use_teacher_forcing = random.random() < TEACHER_FORCING_RATIO

    step_losses = []
    step_predictions = []
    for step in range(max_steps):
        step_target = target_batch[:, step]

        # One decoder step: (batch, 1) input -> (batch, 1, vocab) output.
        step_output, decoder_state = decoder_model(decoder_input, decoder_state)
        step_log_probs = step_output.squeeze(1)  # (batch, vocab)
        step_losses.append(loss_criterion(step_log_probs, step_target))

        # Greedy prediction, kept as (batch, 1) so it can be fed straight back in.
        _, top_indices = step_log_probs.topk(1, dim=1)
        step_predictions.append(top_indices.squeeze(1).detach())

        if use_teacher_forcing:
            # Feed the ground truth of the token that was just predicted.
            # Feeding target_batch[:, step + 1] instead would hand the decoder
            # the exact token it is scored against at the next step, so it
            # would only learn to copy its input. This also matches the
            # free-running branch, which feeds the prediction for `step`.
            decoder_input = step_target.view(batch_size, 1)
        else:
            decoder_input = top_indices

    # --- Sequence accuracy: a sequence counts only if every position matches ---
    if step_predictions:
        predictions = torch.stack(step_predictions, dim=1)  # (batch, max_steps)
    else:
        predictions = torch.empty(batch_size, 0, dtype=torch.long, device=target_batch.device)

    num_correct = 0
    if predictions.size(1) == target_batch.size(1):
        num_correct = (predictions == target_batch).all(dim=1).sum().item()
    else:
        print("Warning: Prediction length mismatch in training step.")

    # --- Loss averaged over decoder steps, then backprop + update ---
    mean_loss = sum(step_losses) / max_steps if max_steps > 0 else 0.0

    encoder_opt.zero_grad()
    decoder_opt.zero_grad()
    if isinstance(mean_loss, torch.Tensor) and mean_loss.requires_grad:
        mean_loss.backward()
    else:
        print("Error: Loss is not a tensor, cannot backpropagate.")
    # Gradient clipping (e.g. torch.nn.utils.clip_grad_norm_) would go here if needed.
    encoder_opt.step()
    decoder_opt.step()

    mean_loss_value = mean_loss.item() if isinstance(mean_loss, torch.Tensor) else mean_loss
    return num_correct, mean_loss_value


def train_loop(encoder, decoder, h_params, data, train_dataloader, val_dataloader, device_param): # Names kept
    """Train the encoder/decoder pair, validating and logging after every epoch.

    Args:
        encoder: encoder model, trained in place.
        decoder: decoder model, trained in place.
        h_params: mapping with at least "optimizer", "learning_rate" and
            "epochs" (plus whatever the step/evaluate helpers read).
        data: preprocessing structure forwarded to the step and evaluate helpers.
        train_dataloader: yields ``(source_batch, target_batch)`` pairs. Batches
            are used as-is (no ``.to(device)`` here). This assumes the dataset
            already holds tensors on the models' device — TODO confirm.
        val_dataloader: loader passed to ``evaluate`` after each epoch.
        device_param: device (``torch.device`` or string) the models live on.
            Used here only to decide whether to release the CUDA cache.

    Returns:
        ``(encoder, decoder, loss_fn)``: the same model objects (now trained)
        and the NLLLoss instance used for training.
    """
    encoder_optimizer, decoder_optimizer = _initialize_optimizers(
        encoder, decoder, h_params["optimizer"], h_params["learning_rate"]
    )

    # NLLLoss expects log-probabilities; the decoder is assumed to end in LogSoftmax.
    loss_criterion_instance = nn.NLLLoss()
    num_epochs = h_params["epochs"]
    # torch.device() accepts both device objects and strings such as "cuda:0".
    running_on_cuda = torch.device(device_param).type == 'cuda'

    overall_training_start_time = time.monotonic()

    for epoch_count in range(num_epochs):  # 0-based; logged as 1-based
        epoch_timer_start = time.monotonic()

        encoder.train()
        decoder.train()

        epoch_loss_sum = 0.0
        epoch_correct = 0
        # Count what was actually processed: with drop_last=True the loader
        # skips the final partial batch, so len(dataset) would overstate the
        # denominator and understate training accuracy.
        epoch_samples = 0
        epoch_batches = 0

        for batch_idx_train, (s_batch, t_batch) in enumerate(train_dataloader):
            batch_correct_count, batch_avg_loss = _training_step_for_batch(
                encoder, decoder, s_batch, t_batch,
                encoder_optimizer, decoder_optimizer,
                h_params, data, loss_criterion_instance,
                epoch_count, batch_idx_train
            )
            epoch_correct += batch_correct_count
            epoch_loss_sum += batch_avg_loss  # already averaged over decoder steps
            epoch_samples += s_batch.size(0)
            epoch_batches += 1

        avg_train_loss_epoch = epoch_loss_sum / epoch_batches if epoch_batches > 0 else 0.0
        train_accuracy_epoch = epoch_correct / epoch_samples if epoch_samples > 0 else 0.0

        val_accuracy, val_loss = evaluate(encoder, decoder, data, val_dataloader, device_param, h_params, loss_criterion_instance)

        epoch_duration_secs = time.monotonic() - epoch_timer_start

        print_log_msg = (
            f"Epoch: {epoch_count+1}/{h_params['epochs']} | Time: {epoch_duration_secs:.2f}s | "
            f"Train Acc: {train_accuracy_epoch*100:.2f}% | Train Loss: {avg_train_loss_epoch:.4f} | "
            f"Val Acc: {val_accuracy*100:.2f}% | Val Loss: {val_loss:.4f}"
        )
        print(print_log_msg)

        # WandB logging is best-effort: a failure is reported but does not stop training.
        try:
            wandb.log({
                "train_accuracy": train_accuracy_epoch,
                "train_loss": avg_train_loss_epoch,
                "val_accuracy": val_accuracy,
                "val_loss": val_loss,
                "epoch_duration_sec": epoch_duration_secs,
                "epoch": epoch_count + 1
            })
        except Exception as e:
            print(f"WandB logging failed for epoch {epoch_count+1}: {e}")

        if running_on_cuda:
            # Return cached, unused blocks from PyTorch's CUDA allocator between epochs.
            torch.cuda.empty_cache()

    total_train_time_secs = time.monotonic() - overall_training_start_time
    print(f"Total training duration: {total_train_time_secs // 60:.0f}m {total_train_time_secs % 60:.0f}s")

    return encoder, decoder, loss_criterion_instance

# --- Main Training Function ---
def train(h_params, data, device_param, train_dataloader, val_dataloader): # Names kept
    """Build a fresh Encoder/Decoder pair on ``device_param`` and train it.

    Args:
        h_params: hyperparameter mapping forwarded to the model constructors
            and to ``train_loop``.
        data: preprocessing structure forwarded to the models and ``train_loop``.
        device_param: device the models are constructed for and moved to.
        train_dataloader: training batches for ``train_loop``.
        val_dataloader: validation batches for ``train_loop``.

    Returns:
        ``(encoder, decoder, loss_fn)`` exactly as produced by ``train_loop``.
    """
    # Construct each model with the device, then also move it there explicitly.
    encoder_net, decoder_net = (
        model_cls(h_params, data, device_param).to(device_param)
        for model_cls in (Encoder, Decoder)
    )

    trained_encoder, trained_decoder, loss_fn = train_loop(
        encoder_net, decoder_net, h_params, data,
        train_dataloader, val_dataloader, device_param
    )
    return trained_encoder, trained_decoder, loss_fn


# --- Dataloader Preparation (Refactored) ---
def prepare_dataloaders(train_source_list, train_target_list, val_source_list, val_target_list, h_params_cfg): # Names kept
    """Build the training/validation DataLoaders plus the preprocessing structure.

    Vocabularies and maximum lengths are learned from the training pairs only.
    The validation pairs are padded to the training maxima and encoded with the
    training vocabularies.

    Args:
        train_source_list: training source strings (deep-copied before preprocessing).
        train_target_list: training target strings (deep-copied before preprocessing).
        val_source_list: validation source strings.
        val_target_list: validation target strings.
        h_params_cfg: mapping providing "batch_size" (plain dict or wandb.config).

    Returns:
        ``(train_dl, val_dl, data_processing_object)``, where the last item is
        the structure returned by ``preprocess_data`` (vocab managers, max
        lengths, encoded training sequences).
    """
    batch_size = h_params_cfg["batch_size"]

    def _make_loader(dataset, shuffle):
        # drop_last=True keeps every batch at exactly batch_size. The models'
        # getInitialState() takes no batch-size argument, so a short final batch
        # presumably would not match the initial hidden state — TODO confirm.
        # pin_memory=False: the sequences are encoded onto `device` (see the
        # generate_string_to_sequence calls), and the training loop does not
        # move batches. Pinning only applies to dense CPU tensors, so with
        # CUDA-resident data pin_memory=True makes the loader raise.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=True,
            num_workers=0,
            pin_memory=False,
        )

    # --- Training data: builds the vocabularies from the training pairs only ---
    # Deep copies protect the caller's lists in case preprocessing mutates them.
    data_processing_object = preprocess_data(copy.deepcopy(train_source_list), copy.deepcopy(train_target_list))

    train_dataset_instance = MyDataset(
        (data_processing_object["source_data_seq"], data_processing_object["target_data_seq"])
    )
    train_dl = _make_loader(train_dataset_instance, shuffle=True)

    # --- Validation data: reuse the training vocabularies and max lengths ---
    source_vocab_mgr_val = data_processing_object["_source_vocab_manager"]
    target_vocab_mgr_val = data_processing_object["_target_vocab_manager"]

    val_padded_s_strings = add_padding(val_source_list, data_processing_object["INPUT_MAX_LENGTH"])
    val_padded_t_strings = add_padding(val_target_list, data_processing_object["OUTPUT_MAX_LENGTH"])

    val_source_sequences = generate_string_to_sequence(val_padded_s_strings, source_vocab_mgr_val, device)
    val_target_sequences = generate_string_to_sequence(val_padded_t_strings, target_vocab_mgr_val, device)

    validation_dataset_instance = MyDataset((val_source_sequences, val_target_sequences))
    val_dl = _make_loader(validation_dataset_instance, shuffle=False)

    return train_dl, val_dl, data_processing_object


# --- Hyperparameters and WandB Sweep Configuration ---
# Example h_params (original, commented out as it's for single run/sweep)
# h_params = {
# "char_embd_dim": 256,
# "hidden_layer_neurons": 256,
# "batch_size": 32,
# "number_of_layers": 3,
# "learning_rate": 0.0001,
# "epochs": 20,
# "cell_type": "LSTM",
# "dropout": 0.1,
# "optimizer": "adam"
# }

# Sweep parameters: Bayesian search maximizing the logged "val_accuracy".
sweep_params = {
    'method' : 'bayes',
    'name'   : 'DL_A3_Highly_Refactored_Sweep_v2', # New sweep name
    'metric' : {
        'goal' : 'maximize',
        'name' : 'val_accuracy',
    },
    'parameters' : { # Slightly adjusted ranges for variety
        'epochs':{'values' : [10, 15, 20]},
        'learning_rate':{'values' : [0.001, 0.0005, 0.0002]}, # Finer control
        'batch_size':{'values':[32, 64]}, # Keeping batch sizes manageable
        'char_embd_dim':{'values' : [128, 256, 384] } ,
        'number_of_layers':{'values' : [1, 2, 3]},
        'optimizer':{'values':['nadam','adam']},
        'cell_type':{'values' : ["LSTM", "GRU"]},
        'hidden_layer_neurons':{'values': [256, 384, 512]},
        'dropout':{'values': [0.1, 0.2, 0.3, 0.0]} # Added 0.0 for no dropout
    }
}

# Create the sweep once, at import time. This needs a working WandB login.
# The login step at the top of the file deliberately tolerates failure, so do
# the same here: on error, report it and leave SWEEP_ID_GENERATED = None
# instead of aborting the whole script.
try:
    SWEEP_ID_GENERATED = wandb.sweep(sweep=sweep_params, project="DA6401_Assignment_3")
    print(f"Generated WandB Sweep ID: {SWEEP_ID_GENERATED}")
except Exception as e:
    SWEEP_ID_GENERATED = None
    print(f"WandB sweep creation failed: {e}")

# Main function for WandB agent
def main(): # Function name kept (for wandb.agent)
    """Run one sweep trial; this is the function handed to ``wandb.agent``.

    Steps:

    1. Start a WandB run and name it after the sampled hyperparameters.
    2. Build the dataloaders from the module-level ``train_source`` /
       ``train_target`` / ``val_source`` / ``val_target`` lists.
    3. Train with ``wandb.config`` as the hyperparameters.

    If ``wandb.init`` itself fails there is no config to train with, so the
    trial is abandoned. A failure while only naming the run is reported, and
    training continues.
    """
    run_config = None
    try:
        # Same project as the sweep/agent, so a direct call to main() logs next
        # to the sweep runs. (Under an agent, the sweep's project applies anyway.)
        wandb.init(project=PROJECT_NAME_FOR_AGENT)
        run_config = wandb.config

        # Descriptive run name for the WandB UI; "e-0" -> "e-" shortens the LR exponent.
        current_run_name = (
            f"{run_config.cell_type}_{run_config.optimizer}_ep{run_config.epochs}_"
            f"lr{run_config.learning_rate:.0e}_emb{run_config.char_embd_dim}_"
            f"hid{run_config.hidden_layer_neurons}_bs{run_config.batch_size}_"
            f"layers{run_config.number_of_layers}_drop{run_config.dropout:.1f}"
        ).replace("e-0", "e-")
        wandb.run.name = current_run_name
    except Exception as e:
        print(f"Error during WandB initialization or run naming: {e}")
        if run_config is None:
            print("Cannot proceed without wandb config in sweep mode. Exiting this agent run.")
            return

    train_dataloader_sweep, val_dataloader_sweep, data_struct_sweep = prepare_dataloaders(
        train_source, train_target,
        val_source, val_target,
        run_config
    )

    # train() logs per-epoch metrics to the active WandB run; the agent closes
    # the run when this function returns.
    train(run_config, data_struct_sweep, device, train_dataloader_sweep, val_dataloader_sweep)

# --- Agent Execution ---
# wandb.agent is called below with a bare sweep ID plus `project=`; the fully
# qualified alternative is "<entity>/<project>/<sweep_id>". By default the agent
# joins the sweep created above (SWEEP_ID_GENERATED). Assign an existing ID here
# (e.g. "hw3b5jng") to resume that sweep instead.
AGENT_SWEEP_ID = SWEEP_ID_GENERATED
PROJECT_NAME_FOR_AGENT = "DA6401_Assignment_3" # Explicitly state project for agent

# Set to True to run a single training session with FIXED_H_PARAMS instead of
# starting the sweep agent.
RUN_SINGLE_TRAINING = False
FIXED_H_PARAMS = {
    "char_embd_dim": 128, "hidden_layer_neurons": 256, "batch_size": 32,
    "number_of_layers": 2, "learning_rate": 0.001, "epochs": 5,  # Short epochs for a test run
    "cell_type": "GRU", "dropout": 0.1, "optimizer": "adam",
}

if __name__ == "__main__":
    if RUN_SINGLE_TRAINING:
        # --- Single run (fixed hyperparameters) ---
        print("Attempting a single, non-sweep run configuration...")
        single_run_wandb_name = (
            f"SINGLE_{FIXED_H_PARAMS['cell_type']}_{FIXED_H_PARAMS['optimizer']}_ep{FIXED_H_PARAMS['epochs']}_"
            f"lr{FIXED_H_PARAMS['learning_rate']:.0e}_emb{FIXED_H_PARAMS['char_embd_dim']}"
        ).replace("e-0", "e-")
        try:
            wandb.init(project=PROJECT_NAME_FOR_AGENT, name=single_run_wandb_name, config=FIXED_H_PARAMS)
            train_dl_single, val_dl_single, data_s_single = prepare_dataloaders(
                train_source, train_target, val_source, val_target, FIXED_H_PARAMS
            )
            train(FIXED_H_PARAMS, data_s_single, device, train_dl_single, val_dl_single)
            wandb.finish()
            print("Single run finished.")
        except Exception as e_single:
            print(f"Error during single run: {e_single}")
            # Close a half-started run so it is not left open.
            wandb.finish()
    elif AGENT_SWEEP_ID is None:
        print("No sweep ID available (sweep creation may have failed); not starting the WandB agent.")
    else:
        # --- WandB sweep ---
        print(f"Attempting to start WandB agent for sweep ID: {AGENT_SWEEP_ID} in project: {PROJECT_NAME_FOR_AGENT}")
        try:
            wandb.agent(sweep_id=AGENT_SWEEP_ID, function=main, count=100, project=PROJECT_NAME_FOR_AGENT)
            print("WandB agent process finished or was stopped.")
        except Exception as e_agent:
            print(f"Error starting or running WandB agent: {e_agent}")
            print("Ensure you have run 'wandb login' and the sweep ID is correct for the specified project.")