# Group 26 - Traditional Machine Learning Approach
## SVM

#### Harvey Dennis and William Asbery

Follow the instructions in this notebook to successfully carry out predictions on the test file.

### Dependencies

__PLEASE RUN THE CELL BELOW__

In [1]:
!pip install scikit-learn nltk optuna plotly





In [2]:
import gc
import json
import logging
import time
import pickle
import psutil
import string
import numpy as np
import optuna
import pandas as pd
from plotly.io import show
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline, FeatureUnion
from sklearn.decomposition import PCA
from pathlib import Path
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from contextlib import contextmanager
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    matthews_corrcoef
)
from collections import Counter
from typing import Dict, List, Tuple, Set, Optional


  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'plotly'

### Config

This config specifies the configuration for the paths of the files being used to train, validate and test the model.
Also, the config contains the characteristics of the augmented file generated from the train data.

Please add the relative or absolute paths to the train, dev, test and augmented train files. We provide you with an augmented train file as the augmentation pipeline takes about 2 hours to run.

Once all the paths are added, 

__PLEASE RUN THE CELL BELOW.__

In [None]:
class BaseConfig:
    """
    Paths and augmentation settings shared by the notebook.

    Directory attributes are anchored on BASE_DIR: the folder containing this
    file when ``__file__`` is defined (script execution), otherwise the current
    working directory (``__file__`` does not exist inside a notebook kernel).
    """
    # Resolve the project folder once; Path(__file__) itself is a file, so the
    # data/cache folders must hang off its parent, not off the file path.
    try:
        BASE_DIR = Path(__file__).resolve().parent
    except NameError:
        BASE_DIR = Path.cwd()

    DATA_DIR = BASE_DIR / "data"
    TRAIN_FILE = Path("train.csv")
    DEV_FILE = Path("dev.csv")
    TEST_FILE = Path("test.csv")
    AUG_TRAIN_FILE = Path("train_augmented.csv")
    SAVE_DIR = DATA_DIR / "results"
    CACHE_DIR = BASE_DIR / "cache"

    # Augmentation config, keyed by the label value ("0" / "1") as a string.
    AUGMENTATION_CONFIG = {
        "0": {
            "replace": 0.0,
            "add": 0.1, # 10%
            "translate":{
                "percentage": 1.0,
                "split": {
                    "Claim": 0.15,
                    "Evidence": 0.7,
                    "Both": 0.15
                },
                "src": "en",
                "intermediates": {
                    "fr": 0.5,
                    "de": 0.4,
                    "ja": 0.1
                }
            },
            "synonym_replacement": {
                "percentage": 0.7,
                "replacement_fraction": 0.3,
                "min_similarity": 0.85,
                "min_word_length": 4,
                "word_frequency_threshold": 3,
                "synonym_selection_strategy": "random",
                "enable_random_synonym_insertion": True,
                "synonym_insertion_probability": 0.03,
                "enable_random_word_insertion": True,
                "word_insertion_probability": 0.01,
                "enable_random_deletion": True,
                "deletion_probability": 0.01,
            },
            "x_or_y": {
                "percentage": 0.08,
                "max_choices": 4,
                "num_words_to_augment": {
                    "Claim": 1,
                    "Evidence": 2
                },
                "split": {
                    "Claim": 0.90,
                    "Evidence": 0.05,
                    "Both": 0.05
                }
            }
        },
        "1": {
            "replace": 0.0,
            "add": 1.0,
            "translate":{
                "percentage": 0.8,
                "split": {
                    "Claim": 0.15,
                    "Evidence": 0.7,
                    "Both": 0.15
                },
                "src": "en",
                "intermediates": {
                    "fr": 0.5,
                    "de": 0.4,
                    "ja": 0.1
                }
            },
            "synonym_replacement": {
                "percentage": 0.7,
                "replacement_fraction": 0.3,
                "min_similarity": 0.85,
                "min_word_length": 4,
                "word_frequency_threshold": 3,
                "synonym_selection_strategy": "random",
                "enable_random_synonym_insertion": True,
                "synonym_insertion_probability": 0.03,
                "enable_random_word_insertion": True,
                "word_insertion_probability": 0.01,
                "enable_random_deletion": True,
                "deletion_probability": 0.01,
            },
            "x_or_y": {
                "percentage": 0.02,
                "max_choices": 4,
                "num_words_to_augment": {
                    "Claim": 1,
                    "Evidence": 2
                },
                "split": {
                    "Claim": 0.90,
                    "Evidence": 0.05,
                    "Both": 0.05
                }
            }
        }
    }
    
def get_config() -> BaseConfig:
    """Create and return a new BaseConfig instance."""
    settings = BaseConfig()
    return settings


# Configuration object referenced by the cells that follow.
config = get_config()

### Best Params from tuning

In [None]:
# Configure logging
# Module-level logger used by the functions defined in the cells below.
logger = logging.getLogger(__name__)

# Hyperparameter values recorded for the final model (see the heading above).
# Within the prediction function below only "vocab_size" is read; the other
# entries are presumably consumed by the training code — verify against it.
params = {
    "vocab_size": 12000,
    "n_gram_range": (1, 2),
    "embedding_dim": 300,
    "pca_components": 540,
    "C": 1.96,
    "tfidf_weighting": True,
    "min_df": 1,
    "max_df": 0.95,
    "kernel": 'rbf',
    "gamma": 'scale'
}

### Data preparation

In [None]:
def prepare_svm_data(data: pd.DataFrame, 
                    remove_stopwords: bool = True, 
                    lemmatize: bool = True, 
                    min_freq: int = 2, 
                    vocab_size: Optional[int] = None) -> Tuple[pd.DataFrame, np.ndarray, Set[str]]:
    """
    Prepare text data for SVM training by cleaning, normalizing and vocabulary management.

    The vocabulary is built from the rows of ``data`` itself. The input frame is
    modified in place (a 'text' column is added) and the same object is returned.

    Args:
        data: DataFrame containing 'Claim' and 'Evidence' columns; a 'label'
            column is optional (absent for unlabeled prediction files).
        remove_stopwords: Whether to remove common stopwords
        lemmatize: Whether to apply lemmatization
        min_freq: Minimum frequency for words to be included in vocabulary
        vocab_size: Maximum vocabulary size (most frequent words kept)

    Returns:
        Tuple containing:
            - Processed DataFrame with added 'text' column
            - NumPy array of labels (empty when 'data' has no 'label' column)
            - Set of vocabulary words
    """
    log = logging.getLogger(__name__)
    translator = str.maketrans('', '', string.punctuation)

    # Load the stopword list once instead of once per text. Failure (e.g. the
    # NLTK corpus is not downloaded) keeps the original best-effort behaviour
    # of leaving stopwords in place, but is now reported instead of hidden.
    custom_stopwords = None
    if remove_stopwords:
        # Keep important discourse markers and modal verbs
        keep_words = {
            'because', 'since', 'therefore', 'hence', 'thus', 'although',
            'however', 'but', 'not', 'should', 'must', 'might', 'may',
            'could', 'would', 'against', 'between', 'before', 'after'
        }
        try:
            custom_stopwords = set(stopwords.words("english")) - keep_words
        except Exception as exc:
            log.warning("Stopword removal skipped: could not load stopwords (%s)", exc)

    # Build the lemmatizer once and probe it so a missing WordNet corpus is
    # detected up front rather than silently on every row.
    lemmatizer = None
    if lemmatize:
        try:
            lemmatizer = WordNetLemmatizer()
            lemmatizer.lemmatize("words")
        except Exception as exc:
            log.warning("Lemmatization skipped: lemmatizer unavailable (%s)", exc)
            lemmatizer = None

    def clean_text(text: str) -> str:
        """
        Clean and normalize text by lowercasing, removing punctuation,
        and optionally removing stopwords and lemmatizing.
        """
        # Empty CSV cells arrive as NaN/None; treat them as empty text.
        if not isinstance(text, str):
            text = "" if pd.isna(text) else str(text)

        # split() also normalizes whitespace
        words = text.lower().translate(translator).split()

        if custom_stopwords is not None:
            words = [word for word in words if word not in custom_stopwords]

        if lemmatizer is not None:
            words = [lemmatizer.lemmatize(word) for word in words]

        return " ".join(words)

    # Clean each column exactly once; the results feed both the vocabulary
    # and the final text column.
    cleaned_claims = data['Claim'].apply(clean_text)
    cleaned_evidence = data['Evidence'].apply(clean_text)

    # Build vocabulary from the cleaned data
    word_counts = Counter()
    for cleaned_column in (cleaned_claims, cleaned_evidence):
        for text in cleaned_column:
            word_counts.update(text.split())

    # Filter words by minimum frequency and sort by frequency (ties alphabetical)
    filtered_words = [(word, count) for word, count in word_counts.items() if count >= min_freq]
    sorted_words = sorted(filtered_words, key=lambda x: (-x[1], x[0]))

    # Apply vocabulary size limit if specified
    if vocab_size is not None:
        sorted_words = sorted_words[:vocab_size]

    vocab = {word for word, _ in sorted_words}

    def replace_rare_words(text: str) -> str:
        """Replace words not in vocabulary with <UNK> token."""
        return ' '.join([word if word in vocab else '<UNK>' for word in text.split()])

    # Process the data with UNK replacement
    data['text'] = ("Claim: " + cleaned_claims.apply(replace_rare_words) + 
                    " [SEP] " + "Evidence: " + cleaned_evidence.apply(replace_rare_words))

    # Extract labels; unlabeled files (e.g. the test set) yield an empty array
    if 'label' in data.columns:
        labels = data['label'].values
    else:
        labels = np.array([])

    return data, labels, vocab

### Prediction Function

In [None]:
def predict_with_saved_model(
    pipeline_path: Path, 
    input_csv_path: Path, 
    output_csv_path: Path
) -> None:
    """
    Loads a saved SVM pipeline, makes predictions on data from an input CSV, 
    and saves the predictions to an output CSV.

    Failures are logged rather than raised; the function returns None in
    every case.

    Args:
        pipeline_path: Path (or path string) to the saved .pkl pipeline file.
        input_csv_path: Path (or path string) to the input CSV file
            (must contain 'Claim' and 'Evidence' columns).
        output_csv_path: Path (or path string) where the predictions CSV will be saved.
    """
    # Accept plain strings as well as Path objects so that callers passing
    # e.g. "model.pkl" do not fail on the Path-only methods used below.
    pipeline_path = Path(pipeline_path)
    input_csv_path = Path(input_csv_path)
    output_csv_path = Path(output_csv_path)

    logger.info("\n" + "="*70)
    logger.info(f"MAKING PREDICTIONS FROM {input_csv_path}")
    logger.info("="*70)

    # --- Input Validation ---
    # is_file() rather than exists(): a directory (including Path(""), which
    # resolves to the current directory) is not a usable pipeline or CSV.
    if not pipeline_path.is_file():
        logger.error(f"Pipeline file not found at {pipeline_path}. Cannot make predictions.")
        return
    if not input_csv_path.is_file():
        logger.error(f"Input CSV file not found at {input_csv_path}. Cannot make predictions.")
        return
    
    # Ensure output directory exists
    output_csv_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        # --- Load Pipeline --- 
        # SECURITY NOTE: pickle.load executes arbitrary code contained in the
        # file; only load pipeline files that come from a trusted source.
        with open(pipeline_path, "rb") as f:
            loaded_pipeline = pickle.load(f)
        logger.info(f"Pipeline loaded successfully from {pipeline_path}")

        # --- Load and Prepare Input Data ---
        input_df = pd.read_csv(input_csv_path)
        logger.info(f"Loaded {len(input_df)} rows from {input_csv_path}")

        if 'Evidence' not in input_df.columns or 'Claim' not in input_df.columns:
            logger.error(f"Input CSV {input_csv_path} must contain 'Evidence' and 'Claim' columns.")
            return

        # Determine training parameters needed for preprocessing
        training_vocab_size = params.get('vocab_size', 12000) 
        logger.info(f"Using parameters for preprocessing: vocab_size={training_vocab_size}")

        # Apply the same preprocessing settings as used during training.
        # NOTE(review): prepare_svm_data rebuilds the vocabulary from the rows
        # it is given, so <UNK> replacement here depends on the prediction
        # file rather than on the training vocabulary — confirm this matches
        # how the pipeline was trained.
        processed_data_df, _, _ = prepare_svm_data(
            input_df, 
            remove_stopwords=True,
            lemmatize=True,        
            min_freq=2, 
            vocab_size=training_vocab_size
        )
        processed_texts = processed_data_df['text'].tolist()
        logger.info(f"Preprocessing complete for {len(processed_texts)} texts.")

        # --- Make Predictions --- 
        predictions = loaded_pipeline.predict(processed_texts)
        logger.info(f"Generated {len(predictions)} predictions.")

        # --- Save Predictions --- 
        predictions_df = pd.DataFrame({'prediction': predictions})
        predictions_df.to_csv(output_csv_path, index=False)
        logger.info(f"Predictions saved successfully to {output_csv_path}")

    except ModuleNotFoundError as e:
        logger.error(f"Error loading pickle: A module required by the pickled object was not found: {e}")
        logger.error("Ensure all necessary libraries and custom classes (GloveVectorizer, etc.) are importable.")
    except FileNotFoundError as e:
        logger.error(f"Error: A required file was not found: {e}")
    except KeyError as e:
        logger.error(f"Error: Missing expected column in input data: {e}")
    except Exception as e:
        logger.error(f"An error occurred during prediction: {e}", exc_info=True)


### Prediction Generation

In [None]:
# Path to the pickled SVM pipeline (.pkl). This MUST be filled in before
# running the cell; an empty string is reported as an error below.
pipeline_pickle_path = ""

if not pipeline_pickle_path:
    logger.error("pipeline_pickle_path is empty. Set it to the saved .pkl pipeline file before predicting.")
else:
    try:
        prediction_input_file = config.DEV_FILE
        prediction_output_file = config.DATA_DIR / "svm_predictions.csv"

        # Ensure the predictions directory exists
        prediction_output_file.parent.mkdir(parents=True, exist_ok=True)

        predict_with_saved_model(
            # The prediction function calls Path methods on this argument,
            # so convert the plain string entered above.
            pipeline_path=Path(pipeline_pickle_path),
            input_csv_path=prediction_input_file, 
            output_csv_path=prediction_output_file
        )
    except Exception as e:
        logger.error(f"Error predicting with saved model: {e}", exc_info=True)

### The code cells below show how we augmented the data, trained the model and evaluated the final model
### These cells don't need to be run (and probably won't work if run, as we didn't use notebooks to develop the SVM). They are included to show how the SVM was created

## Augmentation

### Back Translation

This code performs back translation to create paraphrased versions of text data by translating it to another language and back. This augments the dataset, helping the model by increasing the variety of training examples, improving generalisation and robustness, especially when labeled data is limited.

In [None]:
"""
Back translation module for data augmentation.

This module implements back translation, a technique that translates text to an
intermediate language and then back to the original language to create paraphrased
versions of the original text, used for data augmentation.
"""
import asyncio
from googletrans import Translator


async def back_translate_batch(data: pd.DataFrame, column: str, src='en', intermediate='fr') -> pd.DataFrame:
    """
    Back-translate one or both text columns of a DataFrame in place.

    Each value is translated from ``src`` to ``intermediate`` and back again;
    the round-tripped text overwrites the original cell.

    Args:
        data: DataFrame holding the text to translate (modified in place)
        column: Name of the column to translate, or "Both" to process
            "Claim" followed by "Evidence"
        src: Source language code
        intermediate: Intermediate language code

    Returns:
        The same DataFrame with the selected column(s) replaced
    """
    async with Translator() as translator:
        async def translate_text(text):
            """
            Round-trip a single text, retrying on failure.

            Up to 3 attempts are made, with a 0.5s pause after each failed
            attempt. The untouched input text is returned if every attempt fails.
            """
            attempts_made = 0
            while attempts_made < 3:
                attempts_made += 1
                try:
                    forward = await translator.translate(text, src=src, dest=intermediate)
                    round_trip = await translator.translate(forward.text, src=intermediate, dest=src)
                    return round_trip.text
                except Exception as e:
                    logging.warning(f"Attempt {attempts_made} failed: {e}")
                    await asyncio.sleep(0.5)
            return text

        # "Both" expands to the two text columns, processed in this order.
        target_columns = ["Claim", "Evidence"] if column == "Both" else [column]
        for column_name in target_columns:
            translated_values = []
            for text in data[column_name]:
                translated_values.append(await translate_text(text))
            data.loc[:, column_name] = translated_values

        return data


### Synonym Replacement, Addition, Deletion

The code performs synonym replacement while maintaining the text's semantic meaning. It can also perform random synonym insertion, random word insertion and random word deletion. These methods increase the variety of the training data and should help with generalisation.

Warning: it's quite a lot of code

In [None]:
# Standard library imports
import logging
import random
import re
from collections import defaultdict, Counter
from pathlib import Path

# Third-party imports
import nltk
import numpy as np
import pandas as pd
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, util
from nltk.corpus import wordnet, stopwords

# Suppress warnings
import warnings
warnings.filterwarnings('ignore')

# Download required NLTK resources
# These calls run every time the cell/module is executed.
# NOTE(review): presumably these back the POS tagger, WordNet lookups,
# stopword list and tokenizer used by the class below — confirm the resource
# names against the installed NLTK version.
nltk.download('averaged_perceptron_tagger_eng')
nltk.download('wordnet')
nltk.download('omw-1.4')
nltk.download('stopwords')
nltk.download('punkt_tab')

class AdvancedSynonymReplacer:
    """
    Advanced data augmentation class that performs synonym replacement with semantic similarity control.
    
    This class enhances text data by replacing words with their synonyms while maintaining
    semantic meaning of the original text. It uses Sentence Transformers to verify 
    that augmented sentences remain semantically similar to the original.
    
    Attributes:
        params (dict): Configuration parameters for augmentation.
        train_df (pd.DataFrame): Training data containing 'Evidence', 'Claim', and 'label' columns.
        device (str): Computing device ('cpu' or 'cuda').
        st_model (SentenceTransformer): Model for semantic similarity measurement.
        stop_words (set): Set of English stop words to ignore during augmentation.
        word_frequencies (Counter): Word frequency counter from training data.
    """

    def __init__(self, params: dict, train_df: pd.DataFrame):
        """
        Initialize the synonym replacement augmentation with specified parameters.

        Loads the Sentence Transformer model, reads every setting from
        ``params`` (falling back to a default when a key is absent), copies the
        training frame and runs ``_prepare_data`` on the copy.

        Args:
            params (dict): Configuration parameters for the augmentation process.
            train_df (pd.DataFrame): Training data with 'Evidence', 'Claim', and 'label' columns.
        """
        self.params = params
        self.device = get_device()
        self.stop_words = set(stopwords.words('english'))

        # Load Sentence Transformer Model
        self.st_model_name = params.get('sentence_transformer_model', 'sentence-transformers/all-MiniLM-L6-v2')
        logging.info(f"Loading Sentence Transformer model: {self.st_model_name} onto device: {self.device}")
        self.st_model = SentenceTransformer(self.st_model_name, device=self.device)
        logging.info("Sentence Transformer model loaded.")

        # Configuration parameters
        # The augmentation config in this file spells the threshold
        # "min_similarity"; honour that spelling as a fallback so a configured
        # value is not silently replaced by the default.
        self.min_sentence_similarity = params.get(
            "min_sentence_similarity", params.get("min_similarity", 0.85)
        )
        self.replacement_fraction = params.get("replacement_fraction", 0.5)
        self.batch_size = params.get("batch_size", 1000)
        self.add_original_evidence_to_results = params.get("add_original_evidence_to_results", True)
        self.results_file_name = params.get("output_file", config.DATA_DIR / "advanced_synonym_replacement_results.csv")
        self.min_word_length = params.get("min_word_length", 4)
        self.synonym_selection_strategy = params.get("synonym_selection_strategy", "random")
        self.allow_multi_word_synonyms = params.get("allow_multi_word_synonyms", False)
        self.word_frequency_threshold = params.get("word_frequency_threshold", 5)
        
        # Advanced augmentation settings
        self.enable_random_synonym_insertion = params.get("enable_random_synonym_insertion", False)
        self.synonym_insertion_probability = params.get("synonym_insertion_probability", 0.05)
        
        self.enable_random_word_insertion = params.get("enable_random_word_insertion", False)
        self.word_insertion_probability = params.get("word_insertion_probability", 0.05)
        
        # The config key is "enable_random_deletion"; the previously read key
        # "enable_random_synonym_deletion" is kept as a fallback for callers
        # that already use it.
        self.enable_random_deletion = params.get(
            "enable_random_deletion", params.get("enable_random_synonym_deletion", False)
        )
        self.deletion_probability = params.get("deletion_probability", 0.05)

        # Store original DataFrame and prepare data
        self.train_df = train_df.copy()
        self._prepare_data()

        logging.info("Starting advanced data augmentation with the following parameters:")
        for key, value in params.items():
            logging.info(f" - {key}: {value}")


    def _prepare_data(self):
        """
        Prepare the training data by adding POS tags and calculating word frequencies.

        Tokenizes and POS-tags every evidence sentence (unless a
        'POS_Evidence' column is already present), caches the evidence and
        claim lists used during augmentation, and counts lower-cased evidence
        tokens into ``self.word_frequencies``.
        """
        # Check the column that is actually created and read below. A
        # pre-existing 'POS_Evidence' column is reused as-is, so it must
        # already hold lists of (word, tag) tuples.
        if 'POS_Evidence' not in self.train_df.columns:
            self.train_df['POS_Evidence'] = self.train_df['Evidence'].apply(
                lambda x: nltk.pos_tag(nltk.word_tokenize(x))
            )

        self.original_evidences_pos = self.train_df['POS_Evidence'].tolist()
        self.original_evidences = self.train_df['Evidence'].tolist()
        self.preprocessed_evidences = self.train_df['Evidence'].apply(remove_stopwords).tolist()
        self.corresponding_claim = self.train_df['Claim'].apply(remove_stopwords).tolist()

        # Calculate word frequencies over lower-cased evidence tokens
        all_words = []
        for text in self.train_df['Evidence']:
            all_words.extend(nltk.word_tokenize(text.lower()))
        self.word_frequencies = Counter(all_words)


    def calculate_sentence_similarity(self, sentence_1: str, sentence_2: str) -> float:
        """
        Calculate the semantic similarity between two sentences.

        Uses the Sentence Transformer model to generate embeddings and compute
        the cosine similarity between them.

        Args:
            sentence_1: First sentence to compare
            sentence_2: Second sentence to compare

        Returns:
            float: Cosine similarity of the two sentence embeddings. Cosine
            similarity lies in [-1, 1]; higher values mean the embeddings are
            closer.
        """
        # show_progress_bar is the documented switch for silencing encode();
        # "verbose" is not a documented encode() argument.
        embeddings = self.st_model.encode(
            [sentence_1, sentence_2],
            convert_to_tensor=True,
            device=self.device,
            show_progress_bar=False,
        )
        cosine_scores = util.cos_sim(embeddings[0], embeddings[1])
        return cosine_scores.item()


    def _get_wordnet_pos(self, tag):
        """
        Map NLTK POS tags to WordNet POS tags.
        
        Args:
            tag: NLTK part-of-speech tag
            
        Returns:
            WordNet POS constant or None if no matching tag found
        """
        if tag.startswith('J'):
            return wordnet.ADJ
        elif tag.startswith('V'):
            return wordnet.VERB
        elif tag.startswith('N'):
            return wordnet.NOUN
        elif tag.startswith('R'):
            return wordnet.ADV
        else:
            return None


    def _process_text(
        self,
        text_tokens: list[str],
        pos_tags: list[tuple[str, str]],
        claim_words: set = None,
        is_claim: bool = False
    ) -> list[str]:
        """
        Process text to identify candidate words eligible for replacement.
        
        Analyzes the tokens and their POS tags to determine which words can be
        replaced with synonyms based on various criteria such as word length,
        frequency, and part of speech.
        
        Args:
            text_tokens: List of tokenized words from the text
            pos_tags: List of (word, POS tag) tuples
            claim_words: Set of words in the claim (if processing evidence)
            is_claim: Whether the text being processed is a claim
            
        Returns:
            List of words eligible for synonym replacement
        """
        potential_replacements = []
        safe_pos_tags = {
            'NN', 'NNS', 'JJ', 'JJR', 'JJS',
            'RB', 'RBR', 'RBS',
            'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ'
        }

        pos_tags_dict = defaultdict(list)
        for word, tag in pos_tags:
            pos_tags_dict[word.lower()].append(tag)

        common_words = set()
        if claim_words is not None and not is_claim:
            common_words = set(text_tokens) & claim_words

        for word in text_tokens:
            lower_word = word.lower()

            # Skip if stop word
            if lower_word in self.stop_words:
                continue

            # Skip based on frequency
            if self.word_frequencies.get(lower_word, 0) < self.word_frequency_threshold:
                continue

            # Skip if word is in both claim and evidence (if applicable and not processing claim)
            # or if the word is a substring of any claim word (or vice versa)
            if (not is_claim and (word in common_words or
                    any(word in cw or cw in word for cw in claim_words))):
                continue
                
            # Skip if word not in pos_tags_dict
            if lower_word not in pos_tags_dict:
                continue
                
            # Skip if word is too short
            if len(word) < self.min_word_length:
                continue
                
            # Skip if the word's POS tags are not in the safe list
            if not any(tag in safe_pos_tags for tag in pos_tags_dict[lower_word]):
                continue
                
            potential_replacements.append(word)

        return potential_replacements


    def get_synonyms(self, word: str, pos_tag: str = None, topn: int = 10) -> list[str]:
        """
        Look up WordNet synonyms for a word.

        Synsets are restricted to the WordNet category derived from
        ``pos_tag`` (when one is given). Lemma names equal to the word itself
        are ignored, and multi-word lemmas are kept only when
        ``allow_multi_word_synonyms`` is set. The result is ordered according
        to ``synonym_selection_strategy``.

        Args:
            word: Target word to find synonyms for
            pos_tag: Part-of-speech tag of the word
            topn: Maximum number of synonyms to return

        Returns:
            List of potential synonyms (empty if none found)
        """
        wn_category = self._get_wordnet_pos(pos_tag) if pos_tag else None
        matching_synsets = wordnet.synsets(word, pos=wn_category)
        if not matching_synsets:
            return []

        target = word.lower()
        collected = set()
        for synset in matching_synsets:
            for lemma in synset.lemmas():
                candidate = lemma.name().replace('_', ' ')
                # Lemmas identical to the original word are not synonyms
                if candidate.lower() == target:
                    continue
                # Multi-word lemmas are admitted only when configured
                if self.allow_multi_word_synonyms or ' ' not in candidate:
                    collected.add(candidate)
                if len(collected) >= topn:
                    break
            if len(collected) >= topn:
                break

        ranked = list(collected)

        # Order the candidates according to the configured strategy
        strategy = self.synonym_selection_strategy
        if strategy == 'frequent':
            # Most frequent in the training data first
            ranked.sort(key=lambda s: self.word_frequencies.get(s.lower(), 0), reverse=True)
        elif strategy == 'random':
            random.shuffle(ranked)

        return ranked[:topn]

    def get_random_word(self) -> str:
        """
        Get a random word from the training data vocabulary.
        
        Returns:
            A random word from the training data vocabulary
        """
        return random.choice(list(self.word_frequencies.keys()))


    def find_valid_replacements(
        self,
        word_to_replace: str,
        synonyms: list[str],
        original_text: str,
        original_pos_tags: dict
    ) -> tuple[bool, str]:
        """
        Find a valid synonym replacement that maintains semantic similarity.
        
        Tests each synonym to see if it:
        1. Maintains the same part of speech
        2. Keeps the sentence semantically similar to the original
        
        Args:
            word_to_replace: The original word to be replaced
            synonyms: List of potential synonym candidates
            original_text: The complete original text
            original_pos_tags: Dictionary mapping words to their POS tags
            
        Returns:
            Tuple of (success_flag, replacement_word)
        """
        lower_word = word_to_replace.lower()
        original_word_pos_tags = original_pos_tags.get(lower_word, [])
        
        if not original_word_pos_tags:
            return False, ""
            
        primary_pos = original_word_pos_tags[0]

        for synonym in synonyms:
            # Create pattern to replace only whole words (not substrings)
            pattern = r'\b' + re.escape(word_to_replace) + r'\b'
            
            try:
                new_text = re.sub(pattern, synonym, original_text, flags=re.IGNORECASE)
            except re.error:
                logging.warning(f"Regex error replacing '{word_to_replace}' with '{synonym}'. Skipping.")
                continue

            # Skip if no replacement was made
            if new_text == original_text:
                continue

            # Check if the POS tag is preserved
            new_text_pos = nltk.pos_tag(nltk.word_tokenize(new_text))
            synonym_pos_tags = [tag for (w, tag) in new_text_pos if w.lower() == synonym.lower()]
            
            if not synonym_pos_tags:
                continue

            # Verify that the WordNet POS category is maintained
            if self._get_wordnet_pos(synonym_pos_tags[0]) != self._get_wordnet_pos(primary_pos):
                continue

            # Check semantic similarity to ensure meaning is preserved
            similarity = self.calculate_sentence_similarity(original_text, new_text)
            if similarity >= self.min_sentence_similarity:
                return True, synonym

        return False, ""


    def _random_insertion(self, tokens: list[str], pos_tags: list[tuple[str, str]], add_a_synonym: bool = True) -> list[str]:
        """
        Randomly insert a word or synonym into the token list.
        
        Args:
            tokens: List of tokens to augment
            pos_tags: List of (word, POS tag) tuples
            add_a_synonym: Whether to insert a synonym (True) or random word (False)
            
        Returns:
            Augmented list of tokens with the insertion
        """
        augmented_tokens = list(tokens)
        
        if not augmented_tokens:
            return []
        
        # Choose random position and word
        insert_index = random.randint(0, len(augmented_tokens))
        word_to_augment = random.choice(augmented_tokens)
        lower_word = word_to_augment.lower()
        
        # Get POS tag for the word
        word_pos_dict = dict(pos_tags)
        original_word_pos_tag = word_pos_dict.get(lower_word)
        
        if not original_word_pos_tag:
            return augmented_tokens
        
        # Get either a synonym or random word
        if add_a_synonym:
            candidates = self.get_synonyms(word_to_augment, original_word_pos_tag, topn=5)
        else:
            candidates = [self.get_random_word()]
            
        if not candidates:
            return augmented_tokens
        
        # Insert the selected word
        word_to_insert = random.choice(candidates)
        augmented_tokens.insert(insert_index, word_to_insert)
                    
        return augmented_tokens


    def _random_deletion(self, tokens: list[str]) -> list[str]:
        """
        Randomly delete words from the token list.
        
        Deletes up to 10% of words from the input tokens.
        
        Args:
            tokens: List of tokens to augment
            
        Returns:
            Augmented list of tokens with deletions
        """
        if not tokens:
            return []
        
        # Create a copy of the tokens list
        augmented_tokens = tokens.copy()
        
        # Choose random tokens to delete (up to 10%)
        max_deletions = max(1, int(len(augmented_tokens) * 0.1))
        num_to_delete = random.randint(1, max_deletions)
        indices_to_delete = random.sample(range(len(augmented_tokens)), num_to_delete)
        
        # Create new list without the deleted tokens
        augmented_tokens = [token for i, token in enumerate(augmented_tokens) 
                           if i not in indices_to_delete]
        
        return augmented_tokens


    def augment_data(self):
        """
        Perform data augmentation on the training dataset.

        For every evidence text this optionally inserts a synonym, inserts a
        random word and deletes words (each gated by its enable flag and
        probability), then replaces a fraction of the candidate words with
        synonyms. The result is kept only when its similarity to the original
        evidence is at least ``self.min_sentence_similarity``.

        Accepted rows are buffered and written to ``self.results_file_name`` in
        batches of ``self.batch_size`` rows. If the results file already exists
        the user is asked on stdin whether to overwrite it.

        Returns:
            int: Number of successful augmentations performed (0 if the user
            declines to overwrite an existing results file)
        """
        # Check if results file exists and confirm overwrite
        results_path = Path(self.results_file_name)
        if results_path.exists():
            overwrite = input(
                f"Results file {self.results_file_name} already exists. Overwrite? (y/n) "
            ).strip().lower()
            if overwrite != 'y':
                logging.info("Augmentation aborted by user.")
                return 0
        else:
            results_path.parent.mkdir(parents=True, exist_ok=True)

        # Prepare output dataframe columns
        cols = ["Claim", "Evidence", "label"]
        if self.add_original_evidence_to_results:
            cols.extend(["Original Evidence", "Similarity Score"])

        # Accepted rows are buffered as plain dicts and turned into a DataFrame
        # only when a batch is flushed; growing a DataFrame with pd.concat on
        # every row re-copies the whole batch each time.
        pending_rows = []
        batch_counter = 0
        successful_augmentations = 0
        attempted_augmentations = 0
        total_words_augmented = 0

        # Get data from original dataframe
        original_claims = self.train_df['Claim'].tolist()
        labels = self.train_df['label'].tolist()

        # Process each evidence text
        for idx, original_evidence_text in tqdm(
            enumerate(self.original_evidences),
            desc="Augmenting data",
            total=len(self.original_evidences)
        ):
            attempted_augmentations += 1
            original_claim_text = original_claims[idx]

            # Map each lower-cased word to the POS tags it received
            evidence_pos_tags = self.original_evidences_pos[idx]
            evidence_pos_tags_dict = defaultdict(list)
            for word, tag in evidence_pos_tags:
                evidence_pos_tags_dict[word.lower()].append(tag)
            evidence_tokens = nltk.word_tokenize(original_evidence_text)

            # Start with original tokens
            augmented_evidence_tokens = list(evidence_tokens)

            # Apply random augmentations based on probabilities. The random
            # draw is made even when a step is disabled, so the random stream
            # does not depend on which steps are enabled.
            # 1. Insert synonyms
            should_insert_synonyms = random.random() < self.synonym_insertion_probability
            if self.enable_random_synonym_insertion and should_insert_synonyms:
                augmented_evidence_tokens = self._random_insertion(
                    augmented_evidence_tokens, evidence_pos_tags, add_a_synonym=True)

            # 2. Insert random words
            should_insert_words = random.random() < self.word_insertion_probability
            if self.enable_random_word_insertion and should_insert_words:
                augmented_evidence_tokens = self._random_insertion(
                    augmented_evidence_tokens, evidence_pos_tags, add_a_synonym=False)

            # 3. Delete words
            should_delete_words = random.random() < self.deletion_probability
            if self.enable_random_deletion and should_delete_words:
                augmented_evidence_tokens = self._random_deletion(augmented_evidence_tokens)

            # Find candidate words for synonym replacement
            potential_evidence_replacements = self._process_text(
                augmented_evidence_tokens,
                evidence_pos_tags,
                claim_words=set()
            )

            # Perform the synonym replacements
            num_evidence_replacements = max(0, int(len(potential_evidence_replacements) * self.replacement_fraction))
            current_evidence = " ".join(augmented_evidence_tokens)

            if potential_evidence_replacements and num_evidence_replacements > 0:
                words_to_replace = random.sample(potential_evidence_replacements, k=num_evidence_replacements)

                for word in words_to_replace:
                    # Use the first POS tag recorded for this word
                    word_tags = evidence_pos_tags_dict.get(word.lower())
                    if not word_tags:
                        continue

                    synonyms = self.get_synonyms(word, word_tags[0], topn=10)
                    if not synonyms:
                        continue

                    found, replacement = self.find_valid_replacements(
                        word,
                        synonyms,
                        current_evidence,
                        evidence_pos_tags_dict
                    )
                    if found:
                        pattern = r'\b' + re.escape(word) + r'\b'
                        # A callable makes re.sub insert the synonym literally;
                        # passed as a string it would be parsed as a template
                        # and backslashes could raise or be misread.
                        current_evidence = re.sub(
                            pattern,
                            lambda _match, text=replacement: text,
                            current_evidence,
                            flags=re.IGNORECASE
                        )
                        total_words_augmented += 1

            # Clean up text formatting and spacing
            augmented_evidence_text = self._clean_text_formatting(current_evidence)

            # Validate the final augmented evidence
            final_similarity_score = self.calculate_sentence_similarity(original_evidence_text, augmented_evidence_text)
            if final_similarity_score >= self.min_sentence_similarity:
                # Add the augmented example to the results
                new_row_data = {
                    "Claim": original_claim_text,
                    "Evidence": augmented_evidence_text,
                    "label": labels[idx]
                }

                if self.add_original_evidence_to_results:
                    new_row_data["Original Evidence"] = original_evidence_text
                    new_row_data["Similarity Score"] = final_similarity_score

                pending_rows.append(new_row_data)
                successful_augmentations += 1

            # Save batch if it reaches batch size
            if len(pending_rows) >= self.batch_size:
                self._save_batch(pd.DataFrame(pending_rows, columns=cols), batch_counter)
                pending_rows = []
                batch_counter += 1

        # Save any remaining augmented data
        if pending_rows:
            self._save_batch(pd.DataFrame(pending_rows, columns=cols), batch_counter)

        # Log final statistics
        self._log_augmentation_stats(successful_augmentations, attempted_augmentations, total_words_augmented)

        return successful_augmentations
        
    def _clean_text_formatting(self, text: str) -> str:
        """
        Clean up text spacing and punctuation.
        
        Args:
            text: Text to clean
            
        Returns:
            Cleaned text with proper spacing and punctuation
        """
        # Trim whitespace
        cleaned_text = text.strip()
        
        # Ensure symbols like '$' and '£' have a space before them if not at the start
        cleaned_text = re.sub(r'(?<!\s)([$£])', r' \1', cleaned_text)
        
        # Ensure hyphens are attached with no spaces around them
        cleaned_text = re.sub(r'\s*-\s*', '-', cleaned_text)
        
        # Remove extra spaces before punctuation
        cleaned_text = re.sub(r'\s+([,.;?!])', r'\1', cleaned_text)
        
        # Adjust spacing around quotes
        cleaned_text = re.sub(r'(?<!\s)(["\"])', r' \1', cleaned_text)
        cleaned_text = re.sub(r'(["\"])(\s+)', r'\1', cleaned_text)
        
        # Collapse multiple spaces into one
        cleaned_text = re.sub(r'\s{2,}', ' ', cleaned_text)
        
        return cleaned_text
        
    def _save_batch(self, df: pd.DataFrame, batch_counter: int):
        """
        Write one batch of augmented rows to the results CSV.

        Batch 0 creates (or truncates) the file and writes the header row;
        every later batch is appended without a header.

        Args:
            df: DataFrame containing the batch data
            batch_counter: Zero-based number of the batch being written
        """
        is_first_batch = batch_counter == 0
        logging.info(f"Saving batch {batch_counter} to {self.results_file_name}")
        df.to_csv(
            self.results_file_name,
            index=False,
            mode='w' if is_first_batch else 'a',
            header=is_first_batch,
        )
        
    def _log_augmentation_stats(self, successful: int, attempted: int, words_augmented: int):
        """
        Log statistics about the augmentation process.
        
        Args:
            successful: Number of successful augmentations
            attempted: Number of attempted augmentations
            words_augmented: Total number of words augmented
        """
        logging.info(f"Augmentation completed. {successful} sentences successfully augmented "
                     f"out of {attempted} attempts.")
        logging.info(f"Total words augmented: {words_augmented}")
        
        if attempted > 0:
            success_rate = (successful / attempted) * 100
            logging.info(f"Success rate: {success_rate:.2f}%")
        else:
            logging.info("No augmentation attempts were made.")


class AdvancedSynonymReplacerDF(AdvancedSynonymReplacer):
    """
    In-place synonym replacement for DataFrame augmentation.

    This class extends AdvancedSynonymReplacer to modify the input DataFrame directly
    without creating a separate output file. Useful for integration with transformer
    model pipelines where the entire DataFrame needs to be processed in-memory.

    The original evidence texts are replaced with their augmented versions
    while maintaining the same DataFrame structure. A helper column
    'POS_Evidence' is added to the DataFrame if it is not already present.
    """

    def __init__(self, params: dict, train_df: pd.DataFrame):
        """
        Initialize the in-place DataFrame augmenter.

        Args:
            params: Configuration parameters for the augmentation process
            train_df: Training data with 'Evidence', 'Claim', and 'label' columns
        """
        super().__init__(params, train_df)
        self.train_df = train_df
        self._prepare_data()

    def _prepare_data(self):
        """
        Prepare the training data by adding POS tags and calculating word frequencies.

        Unlike the parent class, this implementation doesn't remove stopwords
        since the full text is directly modified.
        """
        # Check for the column that is actually created and read
        # ('POS_Evidence'); an existing one is reused instead of being
        # recomputed.
        if 'POS_Evidence' not in self.train_df.columns:
            self.train_df['POS_Evidence'] = self.train_df['Evidence'].apply(
                lambda x: nltk.pos_tag(nltk.word_tokenize(x))
            )

        # Calculate word frequencies over the lower-cased evidence tokens
        self.word_frequencies = Counter(
            token
            for text in self.train_df['Evidence']
            for token in nltk.word_tokenize(text.lower())
        )

    def augment_data(self):
        """
        Perform in-place data augmentation on the DataFrame.

        Modifies the 'Evidence' column of the input DataFrame directly with
        augmented versions of the texts. A row is only overwritten when the
        augmented text's similarity to the original is at least
        ``self.min_sentence_similarity``; otherwise it keeps its original text.

        Returns:
            pd.DataFrame: Reference to the modified input DataFrame
        """
        successful_augmentations = 0
        attempted_augmentations = 0
        total_words_augmented = 0

        # Iterate through each row in the DataFrame
        for idx, row in tqdm(
            self.train_df.iterrows(),
            desc="Augmenting data",
            total=len(self.train_df)
        ):
            attempted_augmentations += 1
            original_evidence_text = row['Evidence']

            # Map each lower-cased word to the POS tags it received
            evidence_pos_tags = row['POS_Evidence']
            evidence_pos_tags_dict = defaultdict(list)
            for word, tag in evidence_pos_tags:
                evidence_pos_tags_dict[word.lower()].append(tag)
            evidence_tokens = nltk.word_tokenize(original_evidence_text)

            # Start with original tokens
            augmented_evidence_tokens = list(evidence_tokens)

            # Apply random augmentations based on probabilities. The random
            # draw is made even when a step is disabled, so the random stream
            # does not depend on which steps are enabled.
            # 1. Insert synonyms
            should_insert_synonyms = random.random() < self.synonym_insertion_probability
            if self.enable_random_synonym_insertion and should_insert_synonyms:
                augmented_evidence_tokens = self._random_insertion(
                    augmented_evidence_tokens, evidence_pos_tags, add_a_synonym=True)

            # 2. Insert random words
            should_insert_words = random.random() < self.word_insertion_probability
            if self.enable_random_word_insertion and should_insert_words:
                augmented_evidence_tokens = self._random_insertion(
                    augmented_evidence_tokens, evidence_pos_tags, add_a_synonym=False)

            # 3. Delete words
            should_delete_words = random.random() < self.deletion_probability
            if self.enable_random_deletion and should_delete_words:
                augmented_evidence_tokens = self._random_deletion(augmented_evidence_tokens)

            # Find candidate words for synonym replacement
            potential_evidence_replacements = self._process_text(
                augmented_evidence_tokens,
                evidence_pos_tags,
                claim_words=set()
            )

            # Perform the synonym replacements
            num_evidence_replacements = max(0, int(len(potential_evidence_replacements) * self.replacement_fraction))
            current_evidence = " ".join(augmented_evidence_tokens)

            if potential_evidence_replacements and num_evidence_replacements > 0:
                words_to_replace = random.sample(potential_evidence_replacements, k=num_evidence_replacements)

                for word in words_to_replace:
                    # Use the first POS tag recorded for this word
                    word_tags = evidence_pos_tags_dict.get(word.lower())
                    if not word_tags:
                        continue

                    synonyms = self.get_synonyms(word, word_tags[0], topn=10)
                    if not synonyms:
                        continue

                    found, replacement = self.find_valid_replacements(
                        word,
                        synonyms,
                        current_evidence,
                        evidence_pos_tags_dict
                    )
                    if found:
                        pattern = r'\b' + re.escape(word) + r'\b'
                        # A callable makes re.sub insert the synonym literally;
                        # passed as a string it would be parsed as a template
                        # and backslashes could raise or be misread.
                        current_evidence = re.sub(
                            pattern,
                            lambda _match, text=replacement: text,
                            current_evidence,
                            flags=re.IGNORECASE
                        )
                        total_words_augmented += 1

            # Clean up text formatting and spacing
            augmented_evidence_text = self._clean_text_formatting(current_evidence)

            # Validate the final augmented evidence
            final_similarity_score = self.calculate_sentence_similarity(original_evidence_text, augmented_evidence_text)
            if final_similarity_score >= self.min_sentence_similarity:
                # Update the Evidence text directly in the input DataFrame
                self.train_df.at[idx, 'Evidence'] = augmented_evidence_text
                successful_augmentations += 1

        # Log final statistics
        self._log_augmentation_stats(successful_augmentations, attempted_augmentations, total_words_augmented)

        return self.train_df


### X or Y augmentation

This augmenter finds candidate words in a text and replaces each one with slash-separated alternatives in the form "word/synonym1/synonym2", creating variations of the original text.

In [None]:
class XorYAugmenter:
    """
    A text augmentation class that replaces words with alternatives in X/Y format.
    
    This augmenter finds candidate words in a text and replaces them with a format
    like "word/synonym1/synonym2", creating variations of the original text.
    """
    
    def __init__(self, train_df: pd.DataFrame, similarity_threshold: float = 0.6, 
                 max_choices: int = 2, num_words_to_augment: int = 1):
        """
        Initialize the XorYAugmenter.
        
        Args:
            train_df: Training dataframe containing text to analyze
            similarity_threshold: Threshold for word similarity (0.0-1.0)
            max_choices: Maximum number of alternative words to include
            num_words_to_augment: Number of words to augment in each text
        """
        self.train_df = train_df
        self.similarity_threshold = similarity_threshold
        self.max_choices = max_choices
        self.num_words_to_augment = num_words_to_augment
        
        self.stop_words = set(stopwords.words('english'))
        self.glove_embeddings = glove_embeddings
    
    def _find_candidates(self, claim: str) -> list[tuple[str, str]]:
        """
        Find candidate words for augmentation in the given text.
        
        Args:
            claim: The text to analyze for augmentation candidates
            
        Returns:
            List of (word, POS tag) tuples that are candidates for augmentation
        """
        tokens = nltk.word_tokenize(claim)
        pos = nltk.pos_tag(tokens)
        
        candidates = []
        
        for word, tag in pos:
            # Skip stopwords and words not in our embedding vocabulary
            if word.lower() in self.stop_words or word.lower() not in self.glove_embeddings:
                continue
            
            candidates.append((word, tag))
            
        return candidates
                
    def _get_wordnet_pos(self, nltk_tag: str) -> str:
        """
        Map NLTK POS tags to WordNet POS tags.
        
        Args:
            nltk_tag: POS tag from NLTK tagger
            
        Returns:
            Corresponding WordNet POS tag
        """
        tag_map = {
            'JJ': wordnet.ADJ,
            'NN': wordnet.NOUN,
            'VB': wordnet.VERB,
            'RB': wordnet.ADV,
            'MD': wordnet.VERB
        }
        return tag_map.get(nltk_tag[:2], wordnet.NOUN)
        
    def _get_similar_words(self, word: str, pos_tag: str = None) -> list[str]:
        """
        Find similar words using WordNet.
        
        Args:
            word: The word to find synonyms for
            pos_tag: Part of speech tag to constrain synonyms
            
        Returns:
            List of similar words suitable for augmentation
        """     
        topn = max(4, self.max_choices * 3)
        
        candidates = set()
        wordnet_pos = self._get_wordnet_pos(pos_tag) if pos_tag else None

        synsets = wordnet.synsets(word, pos=wordnet_pos)
        if not synsets:
            return []

        # Collect synonyms from WordNet
        for syn in synsets:
            for lemma in syn.lemmas():
                synonym = lemma.name().replace('_', ' ')
                if synonym.lower() != word.lower() and ' ' not in synonym:
                    candidates.add(synonym)
                    if len(candidates) >= topn:
                        break
            if len(candidates) >= topn:
                break

        # Preserve original capitalization
        synonym_list = list(candidates)
        if word[0].isupper():
            synonym_list = [s.capitalize() for s in synonym_list]
            
        # Sample a subset of synonyms
        synonyms_to_return = random.sample(synonym_list, min(topn, len(synonym_list)))

        return synonyms_to_return
    
    def _augment_text(self, text: str) -> str | None:
        """
        Augment a single text by replacing words with X/Y alternatives.
        
        Args:
            text: The text to augment
            
        Returns:
            Augmented text or None if augmentation was not possible
        """
        candidates = self._find_candidates(text)
        if not candidates:
            return None

        # Determine the number of candidates to use
        num_candidates = min(len(candidates), random.randint(1, self.num_words_to_augment))
        candidates = candidates[:num_candidates]

        for candidate in candidates:
            similar_words = self._get_similar_words(candidate[0], candidate[1])
            if not similar_words:
                continue

            # Select a random number of similar words
            num_words = min(len(similar_words), random.randint(1, self.max_choices - 1))
            similar_words = random.sample(similar_words, num_words)
            similar_words.append(candidate[0])
            random.shuffle(similar_words)

            text = text.replace(candidate[0], '/'.join(similar_words))

        return text

    def augment_data(self, data: pd.DataFrame, augment_claim: bool = True, augment_evidence: bool = False) -> None:
        """
        Augment a dataset by adding X/Y alternatives to selected fields.
        
        This method modifies the dataframe in-place, adding alternatives to either
        claims, evidence, or both depending on the parameters.
        
        Args:
            data: DataFrame containing text to augment
            augment_claim: Whether to augment the 'Claim' column
            augment_evidence: Whether to augment the 'Evidence' column
        """
        for index, row in tqdm(data.iterrows(), total=len(data), desc="Augmenting dataset"):
            if augment_claim:
                new_claim = self._augment_text(row['Claim'])
                if new_claim:
                    data.at[index, 'Claim'] = new_claim
            
            if augment_evidence:
                new_evidence = self._augment_text(row['Evidence'])
                if new_evidence:
                    data.at[index, 'Evidence'] = new_evidence

### Full Augmentation pipeline

In [None]:
def generate_augmented_samples(df: pd.DataFrame, label_counts: int, num_samples: int) -> list:
    """
    Generate a list of index labels of the samples to augment.

    When more samples are requested than rows exist, every row is first
    repeated as many whole times as fits and the remainder is drawn at
    random without replacement.

    Args:
        df (pd.DataFrame): The rows of a single label to draw from
        label_counts (int): The number of rows available for the label
        num_samples (int): The number of samples to augment

    Returns:
        list: Index labels of `df` (with repeats when `num_samples` exceeds
        `label_counts`); empty if nothing is requested or nothing is available
    """
    # Nothing requested or nothing to draw from: avoids dividing by zero and
    # sampling from an empty frame.
    if num_samples <= 0 or label_counts <= 0 or df.empty:
        return []

    indices = []

    if num_samples > label_counts:
        full_repeats, num_samples = divmod(num_samples, label_counts)
        indices.extend(df.index.repeat(full_repeats).tolist())

    if num_samples > 0:
        indices.extend(df.sample(num_samples).index.tolist())

    return indices


async def back_translate_samples(aug_df: pd.DataFrame, label: str) -> pd.DataFrame:
    """
    Back translate the samples for the specified label.

    A random fraction of `aug_df` is selected and split into three groups
    ("Claim", "Evidence", "Both") according to the configured split. Each
    group is divided between the configured intermediate languages and
    passed to `back_translate_batch`. The translated rows are merged back
    into `aug_df`, which is modified in place.

    Args:
        aug_df (pd.DataFrame): The dataframe to augment
        label (str): The label whose augmentation config is used

    Returns:
        pd.DataFrame: The augmented dataframe (same object as `aug_df`)
    """
    translate_cfg = config.AUGMENTATION_CONFIG[label]["translate"]
    src = translate_cfg["src"]
    samples = aug_df.sample(frac=translate_cfg["percentage"])

    splits = translate_cfg["split"]
    languages = translate_cfg["intermediates"]

    claim_count = int(len(samples) * splits["Claim"])
    evidence_count = int(len(samples) * splits["Evidence"])

    split_samples = {
        "Claim": samples.iloc[:claim_count],
        "Evidence": samples.iloc[claim_count: claim_count + evidence_count],
        "Both": samples.iloc[claim_count + evidence_count:]
    }

    # int() truncation can leave a few rows unassigned after the last
    # language. When the language shares are meant to cover the whole group
    # (they sum to 1), those rows go to the last language instead of being
    # silently left untranslated.
    shares_cover_all = sum(languages.values()) >= 1 - 1e-9
    last_lang = list(languages)[-1] if languages else None

    for text_type, sample in split_samples.items():
        count = 0

        for lang, percentage in languages.items():
            remaining = len(sample) - count
            if remaining <= 0:
                break

            # Calculate number of samples for this language
            num_samples = int(len(sample) * percentage)
            if num_samples >= remaining or (shares_cover_all and lang == last_lang):
                num_samples = remaining

            # Nothing to translate for this language in a small group
            if num_samples == 0:
                continue

            # Process the current batch
            current_batch = sample.iloc[count:count + num_samples]
            aug_df.update(await back_translate_batch(current_batch, text_type, src, lang))
            count += num_samples

    return aug_df
    
    
def synonym_replace_samples(aug_df: pd.DataFrame, label: str) -> pd.DataFrame:
    """
    Run synonym replacement on a random fraction of the given rows.

    The selected rows are augmented by AdvancedSynonymReplacerDF and then
    merged back into `aug_df`, which is modified in-place.

    Args:
        aug_df (pd.DataFrame): DataFrame containing samples to augment
        label (str): Label identifier ("0" or "1") to get config parameters

    Returns:
        pd.DataFrame: The modified input DataFrame
    """
    logging.info(f"Starting synonym replacement for label {label}")

    synonym_cfg = config.AUGMENTATION_CONFIG[label]["synonym_replacement"]
    chosen_rows = aug_df.sample(frac=synonym_cfg["percentage"])

    # The replacer rewrites the 'Evidence' text of chosen_rows (not aug_df)
    AdvancedSynonymReplacerDF(synonym_cfg, chosen_rows).augment_data()

    logging.info(f"Completed synonym replacement for label {label}")

    # Merge the augmented rows back into the caller's frame by index
    aug_df.update(chosen_rows)
    return aug_df


def x_or_y_augment_samples(aug_df: pd.DataFrame, label: str) -> pd.DataFrame:
    """
    Apply x or y augmentation to the specified samples.

    A random fraction of `aug_df` is selected and split into three groups:
    rows whose claim is augmented, rows whose evidence is augmented, and
    rows where both are. Each group is augmented on its own copy and merged
    back into `aug_df`, which is modified in place.

    Args:
        aug_df (pd.DataFrame): The dataframe to augment
        label (str): The label whose augmentation config is used

    Returns:
        pd.DataFrame: The augmented dataframe (same object as `aug_df`)
    """
    x_or_y_cfg = config.AUGMENTATION_CONFIG[label]["x_or_y"]
    samples = aug_df.sample(frac=x_or_y_cfg["percentage"])

    splits = x_or_y_cfg["split"]
    claim_count = int(len(samples) * splits["Claim"])
    evidence_count = int(len(samples) * splits["Evidence"])
    both_start = claim_count + evidence_count

    max_choices = x_or_y_cfg["max_choices"]
    claim_num_words_to_augment = x_or_y_cfg["num_words_to_augment"]["Claim"]
    evidence_num_words_to_augment = x_or_y_cfg["num_words_to_augment"]["Evidence"]

    # (rows, words to augment, augment claim?, augment evidence?)
    plan = [
        (samples.iloc[:claim_count], claim_num_words_to_augment, True, False),
        (samples.iloc[claim_count:both_start], evidence_num_words_to_augment, False, True),
        (samples.iloc[both_start:],
         min(claim_num_words_to_augment, evidence_num_words_to_augment), True, True),
    ]

    for rows, num_words, augment_claim, augment_evidence in plan:
        if rows.empty:
            continue

        # Work on an explicit copy and merge that copy back. Writing through
        # an iloc slice only reaches `samples` when the slice happens to be a
        # view, which pandas does not guarantee (and copy-on-write forbids),
        # so the augmented text could otherwise be lost.
        part = rows.copy()
        augmenter = XorYAugmenter(part, max_choices=max_choices, num_words_to_augment=num_words)
        augmenter.augment_data(part, augment_claim=augment_claim, augment_evidence=augment_evidence)
        aug_df.update(part)

    return aug_df


async def main():
    """
    Build the augmented training file.

    Reads the training CSV, draws the configured number of extra rows per
    label, runs back translation, synonym replacement and x-or-y
    augmentation on those rows, appends them to the original data and writes
    the result to the augmented training file.
    """
    aug_df = pd.read_csv(config.TRAIN_FILE)
    aug_path = config.AUG_TRAIN_FILE

    # get label counts
    label_counts = aug_df['label'].value_counts()
    logging.info(f"Label counts: {label_counts}")

    # NOTE(review): the "replace" fractions are only logged; no rows are
    # replaced anywhere in this pipeline — confirm whether that is intended.
    zeros_to_replace = int(label_counts[0] * config.AUGMENTATION_CONFIG["0"]["replace"])
    ones_to_replace = int(label_counts[1] * config.AUGMENTATION_CONFIG["1"]["replace"])

    zeros_to_add = int(label_counts[0] * config.AUGMENTATION_CONFIG["0"]["add"])
    ones_to_add = int(label_counts[1] * config.AUGMENTATION_CONFIG["1"]["add"])

    logging.info(f"Zeros to replace: {zeros_to_replace}")
    logging.info(f"Ones to replace: {ones_to_replace}")
    logging.info(f"Zeros to add: {zeros_to_add}")
    logging.info(f"Ones to add: {ones_to_add}")

    # get the index labels of the zeros and ones to add
    zeros_to_add_indices = generate_augmented_samples(aug_df[aug_df['label'] == 0], label_counts[0], zeros_to_add)
    ones_to_add_indices = generate_augmented_samples(aug_df[aug_df['label'] == 1], label_counts[1], ones_to_add)

    # generate addition df; generate_augmented_samples returns index labels,
    # so rows are selected with .loc (label-based), not .iloc (positional)
    ones_to_add_df = aug_df.loc[ones_to_add_indices].reset_index(drop=True).copy()
    zeros_to_add_df = aug_df.loc[zeros_to_add_indices].reset_index(drop=True).copy()

    # back translation (each step modifies the addition frames in place)
    await back_translate_samples(zeros_to_add_df, "0")
    await back_translate_samples(ones_to_add_df, "1")

    # synonym replacement
    synonym_replace_samples(zeros_to_add_df, "0")
    synonym_replace_samples(ones_to_add_df, "1")

    # x or y augmentation
    x_or_y_augment_samples(zeros_to_add_df, "0")
    x_or_y_augment_samples(ones_to_add_df, "1")

    aug_df = pd.concat([aug_df, zeros_to_add_df, ones_to_add_df], ignore_index=True)

    # Make sure the target directory exists before writing
    Path(aug_path).parent.mkdir(parents=True, exist_ok=True)
    aug_df.to_csv(aug_path, index=False)

if __name__ == "__main__":
    # Run the full augmentation pipeline when this code is executed directly.
    # NOTE(review): asyncio.run() raises RuntimeError if an event loop is
    # already running in this thread (as in a Jupyter kernel) — confirm how
    # this cell is meant to be executed.
    asyncio.run(main())

## SVM Training

### Helper functions

In [None]:
from gensim.downloader import load as glove_embeddings_loader


def get_memory_usage() -> float:
    """
    Report how much memory the current process occupies.

    Returns:
        Resident set size of this process in megabytes (MB)
    """
    bytes_per_mb = 1024 * 1024
    rss_bytes = psutil.Process().memory_info().rss
    return rss_bytes / bytes_per_mb

@contextmanager
def timer(name: str, logger):
    """
    Context manager for timing code execution.

    The elapsed time is logged at INFO level when the block exits, whether
    it finishes normally or raises.

    Args:
        name: Descriptive name for the operation being timed
        logger: Logger object to output timing information

    Example:
        with timer("Data processing", logger):
            process_data()
    """
    # perf_counter is monotonic, so the measured duration is not affected by
    # system clock adjustments the way time.time() differences are.
    start_time = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start_time
        logger.info(f"{name} completed in {elapsed:.2f} seconds")

# Define cache directory and path
# The cache lives in a "cache" folder next to the configured data directory.
CACHE_DIR = config.DATA_DIR.parent / "cache"
# NOTE(review): load_cached_embeddings builds its own per-dimension file name
# and does not read this constant — check for other users before removing it.
EMBEDDINGS_CACHE_PATH = CACHE_DIR / 'glove_embeddings.pkl'

def load_cached_embeddings(embedding_dim=300):
    """
    Load GloVe embeddings of specified dimension from cache if available, otherwise download and cache them.

    The cache file is written to a temporary file first and then renamed into
    place, so a reader never observes a half-written pickle (hyperparameter
    trials may call this function concurrently, and a download can be
    interrupted). A cache file that cannot be unpickled is treated as missing
    and rebuilt.

    Args:
        embedding_dim (int): Desired dimension for GloVe embeddings (50, 100, 200, or 300). Defaults to 300.

    Returns:
        The embeddings object produced by the gensim downloader for
        ``glove-wiki-gigaword-<embedding_dim>`` (callers use it as a
        word -> vector lookup supporting ``word in obj`` and ``obj[word]``).
    """
    import tempfile  # local import: only needed for the atomic cache write

    # Create cache directory if it doesn't exist
    CACHE_DIR.mkdir(exist_ok=True, parents=True)

    cache_path = CACHE_DIR / f'glove_embeddings_{embedding_dim}.pkl'
    if cache_path.exists():
        logging.info(f"Loading GloVe embeddings from cache: {cache_path}")
        try:
            # SECURITY NOTE: pickle.load executes arbitrary code from the file.
            # This is acceptable only because the cache is written locally by
            # this function; never point CACHE_DIR at untrusted content.
            with open(cache_path, 'rb') as f:
                return pickle.load(f)
        except (EOFError, pickle.UnpicklingError) as exc:
            # A truncated/corrupt cache would otherwise fail on every run.
            logging.warning(f"Ignoring unreadable GloVe cache {cache_path}: {exc}")

    model_name = f'glove-wiki-gigaword-{embedding_dim}'
    logging.info(f"Downloading GloVe embeddings with model {model_name} (this might take a while)...")
    glove_embeddings = glove_embeddings_loader(model_name)

    # Cache the embeddings for future use (atomic: write to temp, then rename)
    logging.info(f"Caching GloVe embeddings to: {cache_path}")
    tmp_file = tempfile.NamedTemporaryFile(
        mode='wb',
        dir=CACHE_DIR,
        prefix=f'{cache_path.name}.',
        suffix='.tmp',
        delete=False,
    )
    tmp_path = Path(tmp_file.name)
    try:
        with tmp_file:
            pickle.dump(glove_embeddings, tmp_file)
        tmp_path.replace(cache_path)
    except BaseException:
        # Do not leave partial temp files behind on failure or interruption.
        try:
            tmp_path.unlink()
        except FileNotFoundError:
            pass
        raise

    return glove_embeddings


### Data Preparation for SVM

In [None]:
def prepare_svm_data(data: pd.DataFrame, 
                    remove_stopwords: bool = True, 
                    lemmatize: bool = True, 
                    min_freq: int = 2, 
                    vocab_size: Optional[int] = None,
                    vocab: Optional[Set[str]] = None) -> Tuple[pd.DataFrame, np.ndarray, Set[str]]:
    """
    Prepare text data for the SVM by cleaning, normalizing and vocabulary management.

    Claims and evidence are lowercased, stripped of punctuation and
    whitespace-normalized, then optionally stopword-filtered and lemmatized.
    Words outside the vocabulary are replaced by ``<UNK>`` and the two parts
    are joined as ``"Claim: <claim> [SEP] Evidence: <evidence>"``.

    Note:
        ``data`` is modified in place: a ``'text'`` column is added to it and
        the same object is returned.

    Args:
        data: DataFrame containing 'Claim', 'Evidence' and 'label' columns
        remove_stopwords: Whether to remove common stopwords
        lemmatize: Whether to apply lemmatization
        min_freq: Minimum frequency for words to be included in vocabulary
        vocab_size: Maximum vocabulary size (most frequent words kept)
        vocab: Optional pre-built vocabulary. When given, it is used as-is
            (``min_freq`` and ``vocab_size`` are ignored) instead of building
            one from ``data``; this lets held-out data be encoded with the
            vocabulary learned on the training data.

    Returns:
        Tuple containing:
            - Processed DataFrame with added 'text' column
            - NumPy array of labels
            - Set of vocabulary words
    """
    translator = str.maketrans('', '', string.punctuation)

    # Keep important discourse markers and modal verbs
    keep_words = {
        'because', 'since', 'therefore', 'hence', 'thus', 'although',
        'however', 'but', 'not', 'should', 'must', 'might', 'may',
        'could', 'would', 'against', 'between', 'before', 'after'
    }

    # Load NLTK resources once instead of once per row. Missing resources keep
    # the original best-effort behaviour (the step is skipped) but are now
    # reported instead of being swallowed silently.
    custom_stopwords: Set[str] = set()
    if remove_stopwords:
        try:
            custom_stopwords = set(stopwords.words("english")) - keep_words
        except (LookupError, OSError) as exc:
            logging.warning(f"Stopword removal skipped; NLTK stopwords unavailable: {exc}")

    lemmatizer = None
    if lemmatize:
        try:
            lemmatizer = WordNetLemmatizer()
            lemmatizer.lemmatize("probe")  # force the lazy corpus load now
        except (LookupError, OSError) as exc:
            logging.warning(f"Lemmatization skipped; WordNet unavailable: {exc}")
            lemmatizer = None

    def clean_text(text) -> str:
        """
        Clean and normalize text by lowercasing, removing punctuation,
        and optionally removing stopwords and lemmatizing.
        Missing values (None/NaN) are treated as empty text.
        """
        if not isinstance(text, str):
            is_missing = pd.api.types.is_scalar(text) and pd.isna(text)
            text = "" if is_missing else str(text)

        # split() also normalizes whitespace
        words = text.lower().translate(translator).split()
        if custom_stopwords:
            words = [word for word in words if word not in custom_stopwords]
        if lemmatizer is not None:
            words = [lemmatizer.lemmatize(word) for word in words]
        return " ".join(words)

    # Clean each column once; the results feed both the vocabulary and the text
    clean_claims = data['Claim'].apply(clean_text)
    clean_evidence = data['Evidence'].apply(clean_text)

    if vocab is None:
        # Build vocabulary from the cleaned claims and evidence
        word_counts = Counter(
            word
            for text in pd.concat([clean_claims, clean_evidence])
            for word in text.split()
        )

        # Filter words by minimum frequency; sort by frequency, ties alphabetically
        sorted_words = sorted(
            ((word, count) for word, count in word_counts.items() if count >= min_freq),
            key=lambda item: (-item[1], item[0]),
        )

        # Apply vocabulary size limit if specified
        if vocab_size is not None:
            sorted_words = sorted_words[:vocab_size]

        vocab = {word for word, _ in sorted_words}
    else:
        vocab = set(vocab)

    def replace_rare_words(text: str) -> str:
        """Replace words not in vocabulary with <UNK> token."""
        return ' '.join([word if word in vocab else '<UNK>' for word in text.split()])

    # Process the data with UNK replacement
    data['text'] = ("Claim: " + clean_claims.apply(replace_rare_words) + 
                    " [SEP] " + "Evidence: " + clean_evidence.apply(replace_rare_words))

    # Extract labels
    labels = data['label'].values

    return data, labels, vocab

### GloVe Vectorizer

In [None]:
from sklearn.base import BaseEstimator, TransformerMixin
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer

class GloveVectorizer(BaseEstimator, TransformerMixin):
    """
    A vectorizer that combines GloVe word embeddings with positional encoding.

    Each input document is a string ``<claim><sep_token><evidence>``. It is
    turned into a fixed-size vector by:
    1. Converting words to GloVe embeddings
    2. Averaging them, weighted by the IDF values of a fitted TF-IDF
       vectorizer (optional)
    3. Adding positional encoding information
    4. Computing interaction features between claim and evidence

    Output layout per document (``6 * embedding_dim`` values):
    claim vector, evidence vector, claim positional features, evidence
    positional features, element-wise product, absolute difference.
    """

    def __init__(self, sep_token: str = '[SEP]', use_tfidf_weighting=True, vocabulary=None, 
                 embedding_dim=300, ngram_range=(1,1), min_df=2, max_df=0.95):
        """
        Initialize the GloveVectorizer.

        Args:
            sep_token: Token used to separate claim and evidence. Defaults to '[SEP]'.
            use_tfidf_weighting: Whether to use TF-IDF weights for word embeddings. Defaults to True.
            vocabulary: Set of words to include in the vocabulary. When empty
                or None, words are not restricted and the TF-IDF vectorizer
                learns its own vocabulary.
            embedding_dim: Desired embedding dimension.
            ngram_range: The lower and upper boundary of the n-grams to be extracted.
            min_df: Minimum document frequency for TF-IDF.
            max_df: Maximum document frequency for TF-IDF.
        """
        # Every constructor argument is stored under its own name:
        # BaseEstimator.get_params() (used by clone(), set_params() and the
        # estimator repr) reads each __init__ parameter back via getattr and
        # fails with AttributeError if one is missing.
        self.sep_token = sep_token
        self.use_tfidf_weighting = use_tfidf_weighting
        self.vocabulary = vocabulary or set()
        self.embedding_dim = embedding_dim
        self.ngram_range = ngram_range
        self.min_df = min_df
        self.max_df = max_df

        self.glove = load_cached_embeddings(embedding_dim)
        self.vector_size = embedding_dim

        # Initialize TF-IDF vectorizer if weighting is enabled.
        # An empty vocabulary is passed as None: TfidfVectorizer rejects an
        # explicitly empty vocabulary at fit time.
        self.tfidf_vectorizer = TfidfVectorizer(
            min_df=min_df, 
            max_df=max_df,
            vocabulary=self.vocabulary or None,
            max_features=len(self.vocabulary) if self.vocabulary else None,
            ngram_range=self.ngram_range
        ) if use_tfidf_weighting else None

    @staticmethod
    def _pre_process(doc: str) -> str:
        """
        Pre-process text by removing unrepresentable characters and quotes.

        Args:
            doc: Input text document

        Returns:
            Pre-processed text with ASCII-only characters and no leading/trailing quotes
        """
        # Drop any non-ASCII characters
        ascii_only = doc.encode('ascii', 'ignore').decode('ascii')
        # Remove any double quotes at the beginning and end of the document
        return ascii_only.strip('"')

    def _get_weighted_vector(self, text: str, tfidf_weights=None) -> np.ndarray:
        """
        Compute weighted average of word vectors using vocabulary words.

        Only words that are in the configured vocabulary (when one was given)
        and have a GloVe embedding contribute; everything else is ignored.

        Args:
            text: Input text
            tfidf_weights: Dictionary mapping words to their weights (the IDF
                values of the fitted TF-IDF vectorizer). Words missing from
                the dictionary get weight 1.0.

        Returns:
            Weighted average vector of the input text's words, or a zero
            vector when no word is usable
        """
        vocabulary = self.vocabulary
        usable_words = [
            word for word in text.split()
            if (not vocabulary or word in vocabulary) and word in self.glove
        ]
        if not usable_words:
            return np.zeros(self.vector_size)

        vectors = [self.glove[word] for word in usable_words]

        if self.use_tfidf_weighting and tfidf_weights:
            weights = np.array([tfidf_weights.get(word, 1.0) for word in usable_words])
            weights = weights / np.sum(weights)  # Normalize weights
            return np.average(vectors, axis=0, weights=weights)

        # Plain mean when weighting is disabled or no weights are available
        return np.mean(vectors, axis=0)

    def _get_positional_encoding(self, max_len: int, d_model: int) -> np.ndarray:
        """
        Generate sinusoidal positional encoding matrix.

        Implements the positional encoding from "Attention Is All You Need" (Vaswani et al., 2017).

        Args:
            max_len: Maximum sequence length to encode
            d_model: Dimensionality of the model/embeddings

        Returns:
            A matrix of shape (max_len, d_model) with positional encodings.
        """
        positions = np.arange(max_len)[:, np.newaxis]  # Shape: (max_len, 1)

        # One frequency per (sin, cos) column pair
        div_term = np.exp(-(np.log(10000.0) / d_model) * np.arange(0, d_model, 2))

        pe = np.zeros((max_len, d_model))

        # Even indices get sine, odd indices get cosine
        pe[:, 0::2] = np.sin(positions * div_term)
        if d_model > 1:  # With d_model == 1 there is no odd column to fill
            # For odd d_model there is one fewer cosine column than sine column
            pe[:, 1::2] = np.cos(positions * div_term[:pe.shape[1]//2])

        return pe

    def _extract_positional_features(self, text: str) -> np.ndarray:
        """
        Extract features using sinusoidal positional encoding.

        Adds a positional encoding to the GloVe vector of every known word
        (position counted among the known words only) and averages the result.
        Unlike the weighted vector, this is not restricted to the vocabulary.

        Args:
            text: Input text string

        Returns:
            A vector of size self.vector_size with positionally encoded features
        """
        word_vectors = [self.glove[word] for word in text.split() if word in self.glove]
        if not word_vectors:
            return np.zeros(self.vector_size)

        # Shape: (sequence_length, embedding_dim)
        stacked = np.stack(word_vectors)
        sequence_length = stacked.shape[0]

        pe = self._get_positional_encoding(sequence_length, self.vector_size)

        # Mean of the positionally encoded vectors
        return np.mean(stacked + pe[:sequence_length], axis=0)

    def fit(self, X, y=None):
        """
        Fit the vectorizer by preparing TF-IDF weights if enabled.

        Args:
            X: Training data. Each element should be a string containing
               claim and evidence separated by self.sep_token.
            y: Target values. Not used in this vectorizer.

        Returns:
            self: Returns the instance itself.
        """
        if self.use_tfidf_weighting:
            # Fit TF-IDF on whole documents with the separator blanked out
            self.tfidf_vectorizer.fit([doc.replace(self.sep_token, " ") for doc in X])
        return self

    def transform(self, X):
        """
        Transform the input data into feature vectors.

        For each input text, this method:
        1. Splits the text into claim and evidence
        2. Computes weighted word embeddings for both parts
        3. Adds positional encoding information
        4. Computes interaction features between claim and evidence

        Args:
            X: Input data. Each element should be a string containing
               claim and evidence separated by self.sep_token.

        Returns:
            Feature matrix of shape (n_documents, 6 * embedding_dim) with
            claim, evidence, positional, and interaction features

        Raises:
            ValueError: If a document does not contain exactly one separator.
        """
        # Word -> IDF lookup, built once per call
        tfidf_weights_dict = {}
        if self.use_tfidf_weighting:
            idf = self.tfidf_vectorizer.idf_
            tfidf_weights_dict = {
                word: idf[idx]
                for word, idx in self.tfidf_vectorizer.vocabulary_.items()
            }

        doc_vectors = []
        for doc in X:
            # Split on [SEP] token to separate claim and evidence
            try:
                claim, evidence = doc.split(self.sep_token)
            except ValueError as ve:
                raise ValueError(f"Document splitting error: Expected 2 parts separated by '{self.sep_token}', but got an error: {ve}") from ve

            claim = self._pre_process(claim)
            evidence = self._pre_process(evidence)

            claim_vector = self._get_weighted_vector(claim, tfidf_weights_dict)
            evidence_vector = self._get_weighted_vector(evidence, tfidf_weights_dict)

            doc_vectors.append(np.concatenate([
                claim_vector, 
                evidence_vector,
                self._extract_positional_features(claim),
                self._extract_positional_features(evidence),
                claim_vector * evidence_vector,          # element-wise product
                np.abs(claim_vector - evidence_vector)   # absolute difference
            ]))

        if not doc_vectors:
            # Keep the result 2-D so downstream stacking still works
            return np.empty((0, 6 * self.vector_size))

        return np.array(doc_vectors)

### Feature Extractor

In [None]:
from nltk.sentiment import SentimentIntensityAnalyzer
import nltk
from scipy.spatial.distance import cosine

class FeatureExtractor:
    """
    A feature extractor that combines various text-based features for evidence detection.

    For every ``claim [SEP] evidence`` document it computes:
    1. The number of distinct whitespace-separated tokens shared by claim and evidence
    2. Sentiment analysis features using VADER (neg/neu/pos/compound for both
       parts, plus the absolute difference of the compound scores)
    3. TF-IDF based cosine similarity between claim and evidence

    Attributes:
        sentiment_analyzer (SentimentIntensityAnalyzer): VADER sentiment analyzer
        tfidf (TfidfVectorizer): TF-IDF vectorizer for computing text similarity
    """

    # Column order of the frame returned by transform()
    _FEATURE_COLUMNS = (
        'word_overlap',
        'claim_sentiment_neg',
        'claim_sentiment_neu',
        'claim_sentiment_pos',
        'claim_sentiment_compound',
        'evidence_sentiment_neg',
        'evidence_sentiment_neu',
        'evidence_sentiment_pos',
        'evidence_sentiment_compound',
        'sentiment_diff',
        'tfidf_similarity',
    )

    def __init__(self):
        """
        Initialize the FeatureExtractor.

        Downloads required NLTK resources only if they are not already present:
        - vader_lexicon: For sentiment analysis
        - punkt: For text tokenization
        """
        # An extractor is created for every pipeline, so avoid contacting the
        # NLTK download server when the resources are already installed.
        for resource_path, package in (
            ('sentiment/vader_lexicon.zip', 'vader_lexicon'),
            ('tokenizers/punkt', 'punkt'),
        ):
            try:
                nltk.data.find(resource_path)
            except LookupError:
                nltk.download(package, quiet=True)

        self.sentiment_analyzer = SentimentIntensityAnalyzer()
        self.tfidf = TfidfVectorizer(max_features=100, stop_words='english')

    @staticmethod
    def _split_pair(text):
        """Split one document into (claim, evidence) around its single '[SEP]'."""
        parts = text.split("[SEP]")
        if len(parts) != 2:
            raise ValueError(
                "Document splitting error: expected exactly one '[SEP]' between "
                f"claim and evidence, found {len(parts) - 1}"
            )
        return parts[0], parts[1]

    def fit(self, X, y=None):
        """
        Fit the feature extractor by preparing TF-IDF weights.

        The TF-IDF vectorizer is fitted on all claims and all evidence texts
        as separate documents.

        Args:
            X (array-like): Training data. Each element should be a string containing
                          claim and evidence separated by '[SEP]'.
            y (array-like, optional): Target values. Not used in this extractor.

        Returns:
            self: Returns the instance itself.

        Raises:
            ValueError: If a document does not contain exactly one '[SEP]'.
        """
        all_texts = []
        for text in X:
            all_texts.extend(self._split_pair(text))

        self.tfidf.fit(all_texts)
        return self

    def transform(self, X):
        """
        Transform input texts into feature vectors.

        Args:
            X (array-like): Input texts. Each element should be a string containing
                          claim and evidence separated by '[SEP]'.

        Returns:
            pd.DataFrame: One row per input text with the sentiment, overlap
            and similarity features (11 columns). Empty input yields an empty
            frame with the same columns.

        Raises:
            ValueError: If a document does not contain exactly one '[SEP]'.
        """
        columns = list(self._FEATURE_COLUMNS)
        pairs = [self._split_pair(text) for text in X]
        if not pairs:
            return pd.DataFrame(columns=columns, dtype=float)

        # Vectorize all claims / all evidence in two batched calls instead of
        # two calls per document; rows are independent, so values are unchanged.
        claim_tfidf = self.tfidf.transform([claim for claim, _ in pairs]).toarray()
        evidence_tfidf = self.tfidf.transform([evidence for _, evidence in pairs]).toarray()

        features = []
        for (claim, evidence), claim_vec, evidence_vec in zip(pairs, claim_tfidf, evidence_tfidf):
            claim_sentiments = self.sentiment_analyzer.polarity_scores(claim)
            evidence_sentiments = self.sentiment_analyzer.polarity_scores(evidence)

            # Cosine similarity is undefined for an all-zero vector
            if np.sum(claim_vec) > 0 and np.sum(evidence_vec) > 0:
                tfidf_similarity = 1 - cosine(claim_vec, evidence_vec)
            else:
                tfidf_similarity = 0.0

            features.append({
                'word_overlap': len(set(claim.split()) & set(evidence.split())),
                'claim_sentiment_neg': claim_sentiments['neg'],
                'claim_sentiment_neu': claim_sentiments['neu'],
                'claim_sentiment_pos': claim_sentiments['pos'],
                'claim_sentiment_compound': claim_sentiments['compound'],
                'evidence_sentiment_neg': evidence_sentiments['neg'],
                'evidence_sentiment_neu': evidence_sentiments['neu'],
                'evidence_sentiment_pos': evidence_sentiments['pos'],
                'evidence_sentiment_compound': evidence_sentiments['compound'],
                'sentiment_diff': abs(claim_sentiments['compound'] - evidence_sentiments['compound']),
                'tfidf_similarity': tfidf_similarity,
            })

        return pd.DataFrame(features, columns=columns)


### Tuning/Training Pipeline

In [None]:
# Configure logging
# Module-level logger used by the tuning/training code in this cell.
logger = logging.getLogger(__name__)

# Configuration Constants
NUM_TRIALS = 100  # Number of Optuna trials requested by hyperparameter_tuning()
TRAIN_SUBSET_FRACTION = 0.4  # Use 40% of training data for faster iteration

# Load and prepare initial data
# Record the process memory footprint before the datasets are read.
initial_memory = get_memory_usage()
logger.info(f"Initial memory usage: {initial_memory:.2f} MB")

# Raw frames are read at import/cell-execution time and shared as globals by
# objective() and main(). Training uses the augmented file.
train_df_raw = pd.read_csv(config.AUG_TRAIN_FILE)
dev_df_raw = pd.read_csv(config.DEV_FILE)



def calculate_all_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """
    Calculate comprehensive evaluation metrics for classification.

    Precision/recall/F1 are reported twice: macro-averaged ('Macro-*') and
    support-weighted ('W Macro-*'). Metrics that are undefined because a class
    was never predicted (or never occurs) are reported as 0 without emitting
    a warning.

    Args:
        y_true: Array of true labels
        y_pred: Array of predicted labels

    Returns:
        Dictionary containing accuracy, precision, recall, F1-score, and MCC metrics
    """
    # zero_division=0 yields the same value as the default but silences the
    # UndefinedMetricWarning, which would otherwise repeat on every trial.
    macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='macro', zero_division=0
    )

    weighted_precision, weighted_recall, weighted_f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='weighted', zero_division=0
    )

    return {
        'Accuracy': accuracy_score(y_true, y_pred),
        'Macro-P': macro_precision,
        'Macro-R': macro_recall,
        'Macro-F1': macro_f1,
        'W Macro-P': weighted_precision,
        'W Macro-R': weighted_recall,
        'W Macro-F1': weighted_f1,
        # Matthews Correlation Coefficient
        'MCC': matthews_corrcoef(y_true, y_pred)
    }

# Create training subset for faster iteration if needed
# Result: `train_df_subset`, the frame every tuning trial trains on.
if TRAIN_SUBSET_FRACTION < 1.0:
    # train_samples is only used in the log message below; the split itself
    # is driven by the fraction.
    train_samples = int(len(train_df_raw) * TRAIN_SUBSET_FRACTION)
    logger.info(f"Using {TRAIN_SUBSET_FRACTION*100:.0f}% of training data ({train_samples} samples)")
    
    # Stratify on the label so the subset keeps the class balance of the full
    # set; the fixed seed makes the subset identical across runs. The
    # remainder of the split is discarded.
    train_df_subset, _ = train_test_split(
        train_df_raw, 
        train_size=TRAIN_SUBSET_FRACTION, 
        stratify=train_df_raw['label'],
        random_state=42
    )
    logger.info(f"Stratified subset created. Label distribution:\n{train_df_subset['label'].value_counts(normalize=True)}")
else:
    logger.info("Using 100% of the training data.")
    # No copy is made: the subset is the raw training frame itself.
    train_df_subset = train_df_raw


def objective(trial: optuna.Trial) -> float:
    """
    Evaluate one Optuna trial.

    Samples hyperparameters, prepares the training subset and the development
    set, fits the SVM pipeline, scores it on the development set and writes
    the metrics together with the sampled parameters to
    ``<SAVE_DIR>/svm/svm_<trial number>.json``.

    Args:
        trial: Current Optuna trial object

    Returns:
        float: Weighted Macro F1 score on development set
    """
    global train_df_subset, dev_df_raw
    trial_id = trial.number

    rule = '=' * 50
    logger.info(f"\n{rule}\nStarting trial {trial_id}/{NUM_TRIALS}\n{rule}")
    started_at = time.time()

    # Hyperparameter search space (sampling order is kept fixed)
    params = {}
    params["C"] = trial.suggest_float("C", 1.5, 2.5)  # Focused range around C=2
    params["vocab_size"] = trial.suggest_int("vocab_size", 10000, 20000, step=500)
    params["embedding_dim"] = trial.suggest_categorical("embedding_dim", [100, 200, 300])
    params["pca_components"] = trial.suggest_int("pca_components", 400, 600, step=10)

    # Both splits go through the same preprocessing settings
    prep_options = dict(
        remove_stopwords=True,
        lemmatize=True,
        min_freq=2,
        vocab_size=params['vocab_size'],
    )

    with timer(f"Trial {trial_id} Data Prep", logger):
        # Copies keep the shared global frames untouched
        train_df_trial, train_labels_trial, trial_vocab = prepare_svm_data(
            train_df_subset.copy(), **prep_options
        )
        dev_df_trial, dev_labels_trial, _ = prepare_svm_data(
            dev_df_raw.copy(), **prep_options
        )

        train_texts_trial = train_df_trial['text'].tolist()
        dev_texts_trial = dev_df_trial['text'].tolist()
        logger.info(f"  Trial vocab size: {len(trial_vocab)}")

    # Build the pipeline with the vocabulary learned on the training subset
    pipeline = create_pipeline_from_params(params, trial_vocab)

    with timer(f"Trial {trial_id} training", logger):
        pipeline.fit(train_texts_trial, train_labels_trial)

    with timer(f"Trial {trial_id} evaluation", logger):
        dev_preds = pipeline.predict(dev_texts_trial)
        metrics = calculate_all_metrics(dev_labels_trial, dev_preds)

    def _json_ready(mapping):
        """Convert NumPy floating scalars to plain floats; leave the rest as is."""
        converted = {}
        for key, value in mapping.items():
            converted[key] = float(value) if isinstance(value, np.floating) else value
        return converted

    # Persist this trial's metrics and parameters
    svm_dir = config.SAVE_DIR / "svm"
    svm_dir.mkdir(parents=True, exist_ok=True)

    result_file = svm_dir / f'svm_{trial_id}.json'
    with result_file.open('w') as f:
        serializable_params = _json_ready(params)
        serializable_metrics = _json_ready(metrics)
        json.dump({**serializable_metrics, **serializable_params}, f)

    trial_duration = time.time() - started_at
    logger.info(f"Trial {trial_id} completed in {trial_duration:.2f} seconds")
    logger.info(f"Trial {trial_id} results: W Macro-F1 = {metrics['W Macro-F1']:.4f}")

    # Free memory before the next trial
    gc.collect()
    return metrics["W Macro-F1"]


def create_pipeline_from_params(params: Dict, vocabulary: Set[str]) -> Pipeline:
    """
    Create scikit-learn pipeline for SVM classification.

    The pipeline concatenates GloVe-based features with the hand-crafted
    features of FeatureExtractor, standardizes them, reduces them with PCA
    and classifies with an RBF-kernel SVM.

    Args:
        params: Dictionary of hyperparameters; must contain 'embedding_dim',
            'pca_components' and 'C'
        vocabulary: Vocabulary terms to use in vectorization (the set
            returned by prepare_svm_data)

    Returns:
        Pipeline: Scikit-learn pipeline for text classification
    """
    # Feature extraction component
    feature_union = FeatureUnion([
        ('glove', GloveVectorizer(
            use_tfidf_weighting=True,
            vocabulary=vocabulary,
            embedding_dim=params['embedding_dim'],
            ngram_range=(1, 2),
            min_df=1,
            max_df=0.95
        )),
        ('feature_extractor', FeatureExtractor())
    ])

    return Pipeline([
        ('glove_feature_union', feature_union),
        # Feature scaling and dimensionality reduction
        ('scaler', StandardScaler()),
        # Fixed seed: PCA may select a randomized SVD solver depending on the
        # data shape, which would make trials irreproducible even though the
        # sampler, the data split and the SVM are all seeded.
        ('pca', PCA(n_components=params['pca_components'], random_state=42)),
        # SVM classifier with RBF kernel
        ('svm', SVC(
            C=params['C'],
            kernel='rbf',
            gamma='scale',
            probability=False,
            random_state=42
        )),
    ])


def hyperparameter_tuning(show_plots: bool = False) -> Dict:
    """
    Perform hyperparameter tuning using Optuna.

    Configures and runs an Optuna study (TPE sampler, maximizing the value
    returned by `objective`) for NUM_TRIALS trials. A KeyboardInterrupt stops
    the search early; the best completed trial so far is still reported.

    Args:
        show_plots: Whether to display optimization visualizations

    Returns:
        Dict: Best hyperparameters found during optimization, or an empty
        dict when no trial completed
    """
    from optuna.trial import TrialState

    logger.info(f"Running {NUM_TRIALS} hyperparameter optimization trials...")

    # Configure Optuna sampler and pruner
    sampler = TPESampler(
        seed=42,
        n_startup_trials=int(NUM_TRIALS / 10),
        multivariate=True,
        constant_liar=True
    )
    pruner = MedianPruner(n_startup_trials=5, n_warmup_steps=5, interval_steps=2)

    study = optuna.create_study(
        direction='maximize',
        sampler=sampler,
        pruner=pruner,
        study_name='svm_evidence_detection_streamlined'
    )

    try:
        with timer("Hyperparameter optimization", logger):
            study.optimize(objective, n_trials=NUM_TRIALS, n_jobs=8)
    except KeyboardInterrupt:
        logger.warning("Hyperparameter tuning interrupted by user.")

    # study.best_trial raises when the study holds trials but none of them
    # finished (e.g. interrupted or all failed), so check completed trials
    # rather than study.trials.
    completed_trials = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))
    if not completed_trials:
        logger.error("No trials completed. Exiting.")
        return {}

    # Log best trial results
    trial = study.best_trial
    logger.info("\nBest trial:")
    logger.info(f"  Value (W Macro-F1): {trial.value:.4f}")
    logger.info("  Params:")

    for key, value in trial.params.items():
        logger.info(f"    {key}: {value}")

    # Generate and display plots; plotting problems must not discard the result
    if show_plots:
        try:
            logger.info("Generating Optuna trial plots...")
            history_fig = optuna.visualization.plot_optimization_history(study)
            show(history_fig)
            importance_fig = optuna.visualization.plot_param_importances(study)
            show(importance_fig)
            heatmap_fig = optuna.visualization.plot_contour(study)
            show(heatmap_fig)
        except Exception as e:
            logger.error(f"Failed to generate plots: {e}")

    return trial.params


def main() -> None:
    """
    Main execution function for SVM model training and optimization.

    Performs hyperparameter tuning, trains the final model with optimal 
    parameters on full dataset, evaluates it on the development set and saves
    the fitted pipeline to ``<SAVE_DIR>/svm/svm_pipeline.pkl``. If tuning
    yields no parameters, the function logs an error and returns without
    training.
    """
    global train_df_raw, dev_df_raw

    logger.info("\n" + "="*70)
    logger.info("EVIDENCE DETECTION SVM MODEL TRAINING")
    logger.info("="*70)

    # Find optimal hyperparameters
    params = hyperparameter_tuning(show_plots=True)
    if not params:
        # hyperparameter_tuning() returns {} when no trial completed; indexing
        # it below would fail with an unhelpful KeyError.
        logger.error("Hyperparameter tuning produced no parameters; final model not trained.")
        return

    # Process data with optimal parameters
    train_df_processed, train_labels, best_vocab = prepare_svm_data(
        train_df_raw, 
        remove_stopwords=True, 
        lemmatize=True, 
        min_freq=2, 
        vocab_size=params['vocab_size']
    )

    dev_df_processed, dev_labels, _ = prepare_svm_data(
        dev_df_raw, 
        remove_stopwords=True, 
        lemmatize=True, 
        min_freq=2, 
        vocab_size=params['vocab_size']
    )

    # Train and evaluate final model
    logger.info("\n" + "="*70)
    logger.info("TRAINING FINAL FULL MODEL")
    logger.info("="*70)

    pipeline = create_pipeline_from_params(params, best_vocab)
    pipeline.fit(train_df_processed['text'], train_labels)
    dev_preds = pipeline.predict(dev_df_processed['text'])
    metrics = calculate_all_metrics(dev_labels, dev_preds)
    logger.info(f"Final model evaluation: {metrics}")

    # Save the trained model (best effort: a failed save is logged, not raised)
    pipeline_pickle_path = config.SAVE_DIR / "svm" / "svm_pipeline.pkl"
    try:
        # Do not rely on a tuning trial having created the directory
        pipeline_pickle_path.parent.mkdir(parents=True, exist_ok=True)
        with open(pipeline_pickle_path, "wb") as f:
            pickle.dump(pipeline, f)
        logger.info(f"Pipeline successfully saved to {pipeline_pickle_path}")
    except Exception as e:
        logger.error(f"Error saving pipeline: {e}")


# Entry point: run tuning, final training and model saving when this code is
# executed as the main module.
if __name__ == "__main__":
    main()

### Evaluation

In [None]:
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    matthews_corrcoef,
    confusion_matrix,
)
from matplotlib import pyplot as plt

def compute_metrics(eval_pred):
    """
    Compute classification metrics from per-class scores and true labels.

    The predicted class of each sample is the index of its highest score
    (argmax over axis 1).

    Args:
        eval_pred: Tuple of predictions and labels

    Returns:
        dict: Accuracy, precision/recall/F1 of the class at index 1
        (0 when fewer than two classes are present), support-weighted
        precision/recall/F1, and MCC
    """
    predictions, labels = eval_pred
    predicted_classes = predictions.argmax(axis=1)

    accuracy = accuracy_score(labels, predicted_classes)

    per_class_precision, per_class_recall, per_class_f1, _ = precision_recall_fscore_support(
        labels, predicted_classes, average=None, zero_division=0
    )

    weighted_precision, weighted_recall, weighted_f1, _ = precision_recall_fscore_support(
        labels, predicted_classes, average='weighted', zero_division=0
    )

    mcc = matthews_corrcoef(labels, predicted_classes)

    def _positive_class(scores):
        """Score of the class at index 1, or 0 when that entry does not exist."""
        if len(scores) > 1:
            return scores[1]
        return 0

    return {
        'Accuracy': accuracy,
        'Positive_Precision': _positive_class(per_class_precision),
        'Positive_Recall': _positive_class(per_class_recall),
        'Positive_F1': _positive_class(per_class_f1),
        'W Macro-P': weighted_precision,
        'W Macro-R': weighted_recall,
        'W Macro-F1': weighted_f1,
        'MCC': mcc
    }

def plot_confusion_matrix(y_true, y_pred, save_path):
    """
    Plot and save confusion matrix.

    Each cell shows the raw count and, in parentheses, the count divided by
    the total of its true-label row. Rows without any samples show 0.00
    instead of NaN.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        save_path: Path to save the confusion matrix plot
    """
    cm = confusion_matrix(y_true, y_pred)

    # Row-normalize; leave empty rows at zero to avoid a division by zero
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(
        cm,
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    )

    plt.figure(figsize=(8, 6))
    try:
        plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        plt.title('Confusion Matrix')
        plt.colorbar()

        classes = ['Negative', 'Positive']
        tick_marks = np.arange(len(classes))
        plt.xticks(tick_marks, classes, rotation=45)
        plt.yticks(tick_marks, classes)

        # White text on dark cells, black text on light cells
        thresh = cm.max() / 2.
        for i, j in np.ndindex(cm.shape):
            plt.text(j, i, f'{cm[i, j]}\n({cm_norm[i, j]:.2f})',
                    horizontalalignment="center",
                    color="white" if cm[i, j] > thresh else "black")

        plt.ylabel('True label')
        plt.xlabel('Predicted label')
        # Lay out after the axis labels exist so they are not clipped
        plt.tight_layout()
        plt.savefig(save_path)
    finally:
        # Always release the figure, even if drawing or saving failed
        plt.close()