# Anomaly Detection with Gemini API

This notebook demonstrates how to detect anomalies in text data using **embeddings** generated by the Gemini API.

## 💡 Capabilities of this Notebook
- Runs in **Kaggle or Google Colab** without changes
- Loads and preprocesses newsgroup text data from multiple domains
- Generates text embeddings using Google's Gemini API
- Defines a subset of text (**science newsgroups**) as 'normal' data
- Introduces anomalies by **mixing in unrelated categories in test data**
- Calculates semantic distances using embeddings
- **Flags anomalous texts in the test data** using a distance threshold


## 🔧 Possible Enhancements
- Try different distance metrics (cosine, Mahalanobis)
- Use alternative embedding models (e.g., BERT, OpenAI)
- Introduce real-world noisy or domain-shifted data
- Save and reuse embeddings to avoid repeated API calls
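
The last enhancement, saving and reusing embeddings, could be approached with a small on-disk cache keyed by a hash of the text. The sketch below is illustrative only and is not part of the original notebook: `load_cache`, `save_cache`, `cached_embed`, and the stand-in `fake_embed` are hypothetical names, and in practice the wrapped function would be the real embedding call.

```python
import hashlib
import json
import os
import tempfile


def load_cache(path):
    """Load the cache from disk, or start empty if the file does not exist."""
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    return {}


def save_cache(cache, path):
    """Write the cache to disk as JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(cache, f)


def cached_embed(text, embed_fn, cache):
    """Return the cached embedding for `text`, calling `embed_fn` only on a miss."""
    key = hashlib.sha256(text.encode("utf-8")).hexdigest()
    if key not in cache:
        cache[key] = list(embed_fn(text))
    return cache[key]


# Stand-in for the real API call (assumption: returns a list of floats).
calls = []


def fake_embed(text):
    calls.append(text)
    return [float(len(text)), 1.0]


# A fresh temporary directory keeps the demo independent of earlier runs.
CACHE_PATH = os.path.join(tempfile.mkdtemp(), "embedding_cache.json")

cache = load_cache(CACHE_PATH)
first = cached_embed("hello world", fake_embed, cache)
second = cached_embed("hello world", fake_embed, cache)  # served from the cache
save_cache(cache, CACHE_PATH)

reloaded = load_cache(CACHE_PATH)
third = cached_embed("hello world", fake_embed, reloaded)  # served from disk
print(len(calls))  # the stand-in was called once
```

Hashing the text keeps the keys short and JSON-safe. A cache like this is only valid for one model and task type, so those would need to be part of the key if either changes.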

In [14]:
# Remove unused conflicting packages
#!pip uninstall -qqy jupyterlab kfp 2>/dev/null
# Install specific google-genai version used in the original notebook
!pip install -U -q "google-genai==1.7.0"

## 📚 Import Necessary Libraries

We import the required Python modules for:
- Data loading and preprocessing
- Using the Gemini API to embed text
- Distance computation for anomaly scoring


In [15]:
# Gemini API client, types, and retry helpers
from google import genai
from google.genai import types
from google.api_core import retry
import google.api_core.exceptions # Often needed for specific error types

print(f"Using google-genai version: {genai.__version__}")

# Import necessary libraries
import numpy as np
import pandas as pd
import tensorflow as tf # Used only to seed TensorFlow's RNG below
import time
import concurrent.futures
from tqdm.rich import tqdm as tqdmr
import warnings
import email # Standard Python library for parsing email messages
import re # Regular expressions for pattern matching and text cleaning
from sklearn.datasets import fetch_20newsgroups
from sklearn.metrics.pairwise import cosine_distances # For distance calculation
from sklearn.utils import Bunch # To help manage data sourcing workaround


# For demonstration purposes, we use a random seed for reproducibility.
np.random.seed(42)
tf.random.set_seed(42) # Seeds TensorFlow's RNG (separate from the NumPy seed above)

# Define the retry predicate function (checks for rate limit or server errors)
# Ensure this runs before functions using the @retry decorator
is_retriable = lambda e: isinstance(e, (
    genai.errors.APIError, # General API errors (includes 503)
    google.api_core.exceptions.ResourceExhausted, # Specific error for 429
    google.api_core.exceptions.DeadlineExceeded,
    google.api_core.exceptions.ServiceUnavailable # Handles 503
)) and (not hasattr(e, 'code') or e.code in {429, 503}) # Check code if available

print("Imports and retry logic set up.")

Using google-genai version: 1.7.0
Imports and retry logic set up.


## 🔐 Set Up Gemini API Key

To use the Gemini API for generating embeddings, an API key is required.

### Options:
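- **Kaggle**: The key is read from Kaggle secrets (`UserSecretsClient`) under the name `GOOGLE_API_KEY`.
- **Colab**: The key is read from Colab secrets (`userdata.get('GOOGLE_API_KEY')`). Alternatively, enter it with `getpass` or set it in `os.environ`.
- **Other environments**: The key is read from the `GOOGLE_API_KEY` environment variable.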

> 💡 Make sure your API key has access to the embedding endpoint.
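
Outside Kaggle and Colab, one way to combine the `os.environ` and `getpass` options is to try the environment variable first and prompt only if it is unset. This is a minimal sketch under that assumption; `resolve_api_key` and its parameters are hypothetical and do not appear in the notebook's own cell.

```python
import os
from getpass import getpass


def resolve_api_key(env=None, prompt_fn=None, name="GOOGLE_API_KEY"):
    """Return the API key from the environment, falling back to a prompt.

    env: mapping to read from (defaults to os.environ).
    prompt_fn: callable used when the variable is unset (defaults to getpass).
    """
    env = os.environ if env is None else env
    key = (env.get(name) or "").strip()
    if key:
        return key
    prompt_fn = getpass if prompt_fn is None else prompt_fn
    key = (prompt_fn(f"Enter {name}: ") or "").strip()
    if not key:
        raise ValueError(f"{name} was not provided")
    return key


# Demo with an in-memory mapping, so no real secret or prompt is involved.
demo_key = resolve_api_key(env={"GOOGLE_API_KEY": "demo-key-123"})
print(demo_key)
```

Passing `env` and `prompt_fn` as parameters keeps the function easy to test without touching real secrets.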

In [16]:
# Authenticate and create embedding model client
import os
# `genai` was imported in the imports cell above (`from google import genai`)

# --- Auto-detect Environment and Get API Key ---
client = None # Initialize client to None
GOOGLE_API_KEY = None
environment = "unknown"

print("Attempting to detect environment and configure Google GenAI client...")

# --- Environment Detection using Environment Variables ---
if 'COLAB_GPU' in os.environ:
    print("Detected Colab environment via COLAB_GPU.")
    environment = "colab"
elif 'KAGGLE_KERNEL_RUN_TYPE' in os.environ:
    print("Detected Kaggle environment via KAGGLE_KERNEL_RUN_TYPE.")
    environment = "kaggle"
else:
    # Fallback check using imports if variables aren't definitive
    try:
        from google.colab import userdata
        print("Detected Colab environment via import.")
        environment = "colab"
    except ImportError:
        try:
            from kaggle_secrets import UserSecretsClient
            print("Detected Kaggle environment via import.")
            environment = "kaggle"
        except ImportError:
             print("Could not detect Colab or Kaggle environment via variables or imports.")
             environment = "other"


# --- Get API Key based on detected environment ---
if environment == "colab":
    try:
        from google.colab import userdata # Import again just in case
        GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')
        print("Successfully retrieved GOOGLE_API_KEY from Colab secrets.")
    except userdata.SecretNotFoundError:
        print("Secret 'GOOGLE_API_KEY' not found in Colab secrets.")
    except Exception as e:
        print(f"An error occurred retrieving Colab secret: {type(e).__name__}: {e}")

elif environment == "kaggle":
    try:
        from kaggle_secrets import UserSecretsClient # Import again just in case
        GOOGLE_API_KEY = UserSecretsClient().get_secret("GOOGLE_API_KEY")
        print("Successfully retrieved GOOGLE_API_KEY from Kaggle secrets.")
    except Exception as e: # Catch potential errors during secret retrieval
         print(f"An error occurred retrieving Kaggle secret: {type(e).__name__}: {e}")
         # Check if it's specifically the secret not found error if possible
         if "Secret not found" in str(e): # Simple string check
              print("Secret 'GOOGLE_API_KEY' not found in Kaggle secrets.")
         else:
              print("Please ensure the secret 'GOOGLE_API_KEY' is added to this notebook.")

elif environment == "other":
     # Try environment variable as a last resort
     print("Trying OS environment variable 'GOOGLE_API_KEY'.")
     GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
     if GOOGLE_API_KEY:
          print("Found GOOGLE_API_KEY in OS environment variables.")
     else:
          print("GOOGLE_API_KEY not found as OS environment variable.")
else: # Should not happen, but handle unknown case
     print("Environment detection resulted in an unexpected state.")


# --- Initialize Client ---
if GOOGLE_API_KEY:
    try:
        # Ensure genai.Client is available
        if hasattr(genai, 'Client'):
             client = genai.Client(api_key=GOOGLE_API_KEY)
             print("Successfully configured Google GenAI client.")
             # Optional: Test client
             # try:
             #      client.models.list()
             #      print("Client connection test successful.")
             # except Exception as test_e:
             #      print(f"Client connection test failed: {test_e}")
             #      client = None # Reset client if test fails
        else:
             print("Error: genai.Client class not found. Was the library imported correctly?")
             client = None

    except Exception as client_e:
        print(f"\n--- ERROR: Failed to configure client ---")
        print(f"An error occurred during client configuration: {type(client_e).__name__}: {client_e}")
        client = None
else:
    print("\n--- WARNING ---")
    print("GOOGLE_API_KEY could not be retrieved.")
    print("GenAI Client could not be configured.")
    print("API calls to Gemini will FAIL.")
    print("--- END WARNING ---")




Attempting to detect environment and configure Google GenAI client...
Detected Kaggle environment via KAGGLE_KERNEL_RUN_TYPE.
Successfully retrieved GOOGLE_API_KEY from Kaggle secrets.
Successfully configured Google GenAI client.


## Dataset

* The [20 Newsgroups Text Dataset](https://scikit-learn.org/0.19/datasets/twenty_newsgroups.html) is used as the source for our 'normal' data.
* We load the raw data, preprocess it, and then sample subsets from 'sci.*' categories to define our 'normal' data group.

In [17]:
# Load the full 20 Newsgroups train/test splits (all categories).
# The raw variables are reused in later cells, so keep them accessible.
print("Loading initial train/test splits...")
# Use try-except to handle potential network errors during fetch
try:
    newsgroups_train_raw = fetch_20newsgroups(subset="train")
    newsgroups_test_raw = fetch_20newsgroups(subset="test")
    print(f"Raw train posts: {len(newsgroups_train_raw.data)}")
    print(f"Raw test posts: {len(newsgroups_test_raw.data)}")
    print(f"All categories: {newsgroups_train_raw.target_names}")
except Exception as e:
    print(f"ERROR loading dataset: {e}")
    print("Please ensure internet is enabled for the notebook and try again.")
    # Re-raise so execution stops here; every later cell depends on this data
    raise

Loading initial train/test splits...
Raw train posts: 11314
Raw test posts: 7532
All categories: ['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x', 'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 'sci.med', 'sci.space', 'soc.religion.christian', 'talk.politics.guns', 'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc']


## Prepare the Dataset

### **Objective**

* **Preprocessing**: Clean the raw newsgroup posts.
* **Normalization**: Format the text to resemble standard prose.

### **Processing Steps**

* **Extract**: Use email headers (e.g., "Subject") and the message payload.
* **Remove**: Eliminate email addresses and common headers/footers.
* **Truncate**: Limit the text length (e.g., 5000 characters).
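
The three steps above can be tried on a single toy message before running them over the whole dataset. The post below is made up for illustration, and `MAX_CHARS` is a name introduced here; the regular expression and the 5,000-character limit match the ones used in the notebook's cleaning function.

```python
import email
import re

MAX_CHARS = 5000  # same limit as in the notebook's cleaning function

# A made-up newsgroup post with headers, a blank line, then the body.
raw_post = (
    "From: alice@example.com\n"
    "Subject: Shuttle launch question\n"
    "Lines: 3\n"
    "\n"
    "Contact me at bob@example.org for details.\n"
    "The launch window opens at dawn.\n"
)

# Extract: parse the message and pull out the subject and the body.
msg = email.message_from_string(raw_post)
subject = msg["Subject"] or ""
body = msg.get_payload() or ""

# Remove: strip anything that looks like an email address.
text = f"{subject}\n\n{body}"
text = re.sub(r"[\w\.-]+@[\w\.-]+", "", text)

# Truncate: trim whitespace and cap the length.
text = text.strip()[:MAX_CHARS]
print(text)
```

Because the headers are parsed rather than kept as text, the `From:` and `Lines:` lines never reach the output; only the subject and the body do.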

In [18]:
def preprocess_newsgroup_row(data):
    """
    Processes a single email/newsgroup entry:
    - Extracts the subject and body.
    - Removes email addresses and common clutter.
    - Truncates text to 5,000 characters.

    Args:
        data (str): Raw email message as a string.

    Returns:
        str: Cleaned and truncated text.
    """
    # Parse the email message from the raw string format
    try:
        msg = email.message_from_string(data)
        # Extract subject and body text
        subject = msg['Subject'] if msg['Subject'] else ""
        payload = msg.get_payload() if msg.get_payload() else ""
        # Ensure payload is a string
        if isinstance(payload, list): # Handle multipart messages simply
            # Decode parts if necessary, assuming utf-8 or latin-1
            payload_parts = []
            for part in payload:
                try:
                    if hasattr(part, 'get_payload'):
                        p_content = part.get_payload(decode=True)
                        if p_content:
                             try:
                                  payload_parts.append(p_content.decode('utf-8'))
                             except UnicodeDecodeError:
                                  try:
                                       payload_parts.append(p_content.decode('latin-1'))
                                  except UnicodeDecodeError:
                                       payload_parts.append("[Undecodable Content]") # Placeholder
                        else:
                            payload_parts.append(str(part)) # Fallback for non-payload parts
                    else:
                         payload_parts.append(str(part)) # Fallback for non-message parts
                except Exception as part_e:
                     payload_parts.append(f"[Error processing part: {part_e}]")
            payload = "\n".join(payload_parts)

        elif isinstance(payload, bytes):
             try:
                   payload = payload.decode('utf-8')
             except UnicodeDecodeError:
                   try:
                        payload = payload.decode('latin-1')
                   except UnicodeDecodeError:
                        payload = "[Undecodable Bytes Payload]"


        text = f"{subject}\n\n{str(payload)}" # Ensure payload is string

    except Exception as e:
        # Handle potential parsing errors on malformed messages
        # print(f"Warning: Error parsing email data: {e}. Using raw data snippet.")
        text = str(data)[:5000] # Use raw data as fallback

    # Remove email addresses from the text
    text = re.sub(r"[\w\.-]+@[\w\.-]+", "", text)
    # Remove common header/footer lines (simple examples)
    text = re.sub(r'^\s*Lines: \d+\s*$', '', text, flags=re.MULTILINE)
    text = re.sub(r'^\s*Organization: .*\s*$', '', text, flags=re.MULTILINE)
    text = re.sub(r'^\s*From: .*\s*$', '', text, flags=re.MULTILINE)
    text = re.sub(r'^\s*Subject: .*\s*$', '', text, flags=re.MULTILINE)
    text = re.sub(r'^\s*Nntp-Posting-Host: .*\s*$', '', text, flags=re.MULTILINE)
    text = re.sub(r'^\s*Article-I\.D\.: .*\s*$', '', text, flags=re.MULTILINE)
    text = re.sub(r'^\s*Keywords: .*\s*$', '', text, flags=re.MULTILINE)
    text = re.sub(r'^\s*In article <.*> you write:\s*$', '', text, flags=re.MULTILINE)
    text = re.sub(r'wrote:$', '', text, flags=re.MULTILINE) # Common quote intro


    # Truncate text to 5,000 characters to limit processing size
    text = text.strip()[:5000]

    return text


def preprocess_newsgroup_data(newsgroup_dataset, apply_clean):
    """
    Converts the newsgroup dataset into a structured DataFrame:
    - Stores text and labels.
    - Cleans text using preprocess_newsgroup_row() if apply_clean is True.
    - Maps numeric labels to category names.

    Args:
        newsgroup_dataset (sklearn.utils.Bunch): Newsgroup dataset object.
        apply_clean (boolean) : Apply row cleaning

    Returns:
        pd.DataFrame: Preprocessed dataset with 'Text', 'Label', 'Class Name'.
    """
    # Convert dataset into a pandas DataFrame
    df = pd.DataFrame(
        {"Text": newsgroup_dataset.data, "Label": newsgroup_dataset.target}
    )
    if apply_clean:
        print(f"Applying text cleaning...")
        # Apply text cleaning to each entry
        df["Text"] = df["Text"].apply(preprocess_newsgroup_row)
        # Remove rows where text became empty after cleaning
        df = df[df["Text"].str.strip().astype(bool)]
        print(f"DataFrame shape after cleaning: {df.shape}")


    # Convert numerical labels to category names
    target_names = newsgroup_dataset.target_names
    if target_names:
         # Ensure label exists in target_names mapping
         df["Class Name"] = df["Label"].apply(lambda l: target_names[l] if 0 <= l < len(target_names) else "Unknown")
    else:
         df["Class Name"] = "Unknown"


    return df.reset_index(drop=True) # Reset index after potential row removal

#### Create pandas DataFrames from the training and test datasets

In [19]:
# Create pandas dataframes from training and test datasets *with* preprocessing
# Check if raw data was loaded successfully
if 'newsgroups_train_raw' in locals() and 'newsgroups_test_raw' in locals():
    print("Preprocessing raw train data...")
    df_train_full = preprocess_newsgroup_data(newsgroups_train_raw, apply_clean=True)
    print("\nPreprocessing raw test data...")
    df_test_full = preprocess_newsgroup_data(newsgroups_test_raw, apply_clean=True)

    print("\nFull Train DataFrame head:")
    print(df_train_full.head())
    print(f"\nFull Train DataFrame shape: {df_train_full.shape}")
    print(f"Full Test DataFrame shape: {df_test_full.shape}")

else:
    print("ERROR: Raw data not loaded earlier: Cannot preprocess.")
    # Create empty dataframes to avoid subsequent errors, but notebook won't work
    df_train_full = pd.DataFrame(columns=['Text', 'Label', 'Class Name'])
    df_test_full = pd.DataFrame(columns=['Text', 'Label', 'Class Name'])

Preprocessing raw train data...
Applying text cleaning...
DataFrame shape after cleaning: (11314, 2)

Preprocessing raw test data...
Applying text cleaning...
DataFrame shape after cleaning: (7532, 2)

Full Train DataFrame head:
                                                Text  Label  \
0  WHAT car is this!?\n\n I was wondering if anyo...      7   
1  SI Clock Poll - Final Call\n\nA fair number of...      4   
2  PB questions...\n\nwell folks, my mac plus fin...      4   
3  Re: Weitek P9000 ?\n\nRobert J.C. Kyanko () \n...      1   
4  Re: Shuttle Launch Question\n\nFrom article <>...     14   

              Class Name  
0              rec.autos  
1  comp.sys.mac.hardware  
2  comp.sys.mac.hardware  
3          comp.graphics  
4              sci.space  

Full Train DataFrame shape: (11314, 3)
Full Test DataFrame shape: (7532, 3)


#### Sampling Data to Define the "Normal" Class (as Opposed to Anomalies)

* We sample a subset of the data to represent our "normal" dataset for anomaly detection.
* Here, we choose the `sci` categories.
* We sample the test set the same way to get a baseline of expected 'normal' test points.
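
The idea of "filter by category pattern, then take up to N rows per class" can be shown on toy data without any libraries. This is a simplified sketch, not the notebook's implementation (which uses pandas `groupby` and `sample`); `sample_per_class` and the toy rows are hypothetical.

```python
import random
from collections import defaultdict


def sample_per_class(rows, pattern, n_per_class, seed=42):
    """Keep rows whose class name contains `pattern`, then draw up to
    `n_per_class` rows from each remaining class."""
    by_class = defaultdict(list)
    for row in rows:
        if pattern in row["Class Name"]:
            by_class[row["Class Name"]].append(row)

    rng = random.Random(seed)  # fixed seed for reproducible samples
    sampled = []
    for name in sorted(by_class):
        group = by_class[name]
        # Classes with fewer rows than requested contribute all of their rows.
        sampled.extend(rng.sample(group, min(n_per_class, len(group))))
    return sampled


rows = (
    [{"Text": f"space post {i}", "Class Name": "sci.space"} for i in range(5)]
    + [{"Text": f"med post {i}", "Class Name": "sci.med"} for i in range(2)]
    + [{"Text": f"car post {i}", "Class Name": "rec.autos"} for i in range(4)]
)

subset = sample_per_class(rows, pattern="sci", n_per_class=3)

counts = {}
for row in subset:
    counts[row["Class Name"]] = counts.get(row["Class Name"], 0) + 1
print(counts)  # {'sci.med': 2, 'sci.space': 3}
```

Capping the sample at the group size is what lets a small class (here `sci.med`) pass through without raising an error.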

In [20]:
# Function to sample data
def sample_data(df, num_samples_per_class, classes_to_keep_pattern):
    """
    Samples rows from the dataset based on the specified number of samples per label
    and filters the dataset to keep only specified categories matching a pattern. Handles cases where classes have fewer samples than requested.

    Args:
        df (pd.DataFrame): Input dataframe containing the data ('Text', 'Label', 'Class Name').
        num_samples_per_class (int): Max number of samples to take per label.
        classes_to_keep_pattern (str): Substring pattern to filter class names.

    Returns:
        pd.DataFrame: Filtered and sampled dataframe. Returns empty if no matching classes or input df is empty.
    """
    if df.empty:
         # print("Debug: Input DataFrame to sample_data is empty.") # Removed Debug
         return df

    # Filter rows based on the pattern first
    df['Class Name'] = df['Class Name'].astype(str)
    try:
        filter_mask = df["Class Name"].str.contains(classes_to_keep_pattern, na=False, regex=False)
        df_filtered = df[filter_mask].copy()
    except Exception as filter_e:
        print(f"Error filtering DataFrame by pattern '{classes_to_keep_pattern}': {filter_e}")
        return pd.DataFrame(columns=df.columns) # Return empty on error


    if df_filtered.empty:
        print(f"Warning: No classes found containing pattern '{classes_to_keep_pattern}' AFTER filtering.")
        return df_filtered

    print(f"Found classes containing '{classes_to_keep_pattern}': {df_filtered['Class Name'].unique().tolist()}")
    print(f"Number of samples before sampling (after filtering): {len(df_filtered)}")


    # Sample rows, selecting num_samples_per_class of each remaining label
    try:
        # Ensure the sampling function handles empty groups if any arise
        df_sampled = (
            df_filtered.groupby("Class Name", group_keys=False)
            # Handle potential deprecation warning by explicitly selecting columns if needed,
            # but standard apply should work. Add random_state for reproducibility.
            .apply(lambda x: x.sample(min(num_samples_per_class, x.shape[0]), random_state=42) if not x.empty else None)
        )
        # Remove potential None results if a group was empty
        if df_sampled is not None:
             # Check if df_sampled is a DataFrame before calling dropna
             if isinstance(df_sampled, pd.DataFrame):
                  df_sampled.dropna(inplace=True)
             else: # If apply returned something else (like Series if only one group)
                  print(f"Warning: Unexpected result type from groupby.apply: {type(df_sampled)}")
                  # Attempt to convert back to DataFrame or handle appropriately
                  # For simplicity, return empty if structure is unexpected
                  df_sampled = pd.DataFrame(columns=df_filtered.columns)

        else: # Handle case where apply returns None (e.g., only one empty group)
             df_sampled = pd.DataFrame(columns=df_filtered.columns)


    except Exception as e:
         print(f"Error during sampling: {e}")
         return pd.DataFrame(columns=df.columns) # Return empty df on error

    print(f"Number of samples after sampling: {len(df_sampled)}")
    return df_sampled.reset_index(drop=True) # Reset index after sampling

# --- Define sampling constants ---
TRAIN_NUM_SAMPLES_NORMAL = 100
TEST_NUM_SAMPLES_NORMAL = 25
NORMAL_CLASSES_PATTERN = "sci" # Keep classes containing 'sci' (sci.crypt, sci.electronics, sci.med, sci.space)

# Initialize empty DataFrames for results
df_train = pd.DataFrame(columns=['Text', 'Label', 'Class Name'])
df_test = pd.DataFrame(columns=['Text', 'Label', 'Class Name'])

# --- Sample 'normal' training data ---
print("\n--- Preparing to sample TRAINING data ---")
if 'df_train_full' in locals() and not df_train_full.empty:
    print(f"\nSampling 'normal' training data (pattern: {NORMAL_CLASSES_PATTERN})...")
    df_train = sample_data(df_train_full, TRAIN_NUM_SAMPLES_NORMAL, NORMAL_CLASSES_PATTERN)
    if not df_train.empty:
         print(f"\n'Normal' Training samples per class:\n{df_train['Class Name'].value_counts()}")
    else:
         print("Resulting 'normal' training DataFrame (df_train) is empty after sampling.")
else:
    print("Skipping training data sampling as df_train_full is empty or not defined.")


# --- Sample 'normal' test data ---
print("\n--- Preparing to sample TEST data ---")
if 'df_test_full' in locals() and not df_test_full.empty:
    print(f"\nSampling 'normal' test data (pattern: {NORMAL_CLASSES_PATTERN})...")
    df_test = sample_data(df_test_full, TEST_NUM_SAMPLES_NORMAL, NORMAL_CLASSES_PATTERN)
    if not df_test.empty:
         print(f"\n'Normal' Test samples per class:\n{df_test['Class Name'].value_counts()}")
    else:
         print("Resulting 'normal' test DataFrame (df_test) is empty after sampling.")

else:
     print("Skipping test data sampling as df_test_full is empty or not defined.")




--- Preparing to sample TRAINING data ---

Sampling 'normal' training data (pattern: sci)...
Found classes containing 'sci': ['sci.space', 'sci.med', 'sci.electronics', 'sci.crypt']
Number of samples before sampling (after filtering): 2373
Number of samples after sampling: 400

'Normal' Training samples per class:
Class Name
sci.crypt          100
sci.electronics    100
sci.med            100
sci.space          100
Name: count, dtype: int64

--- Preparing to sample TEST data ---

Sampling 'normal' test data (pattern: sci)...
Found classes containing 'sci': ['sci.med', 'sci.space', 'sci.crypt', 'sci.electronics']
Number of samples before sampling (after filtering): 1579
Number of samples after sampling: 100

'Normal' Test samples per class:
Class Name
sci.crypt          25
sci.electronics    25
sci.med            25
sci.space          25
Name: count, dtype: int64




## Create the Embeddings

Generate embeddings for each piece of text using the Gemini API embeddings endpoint. 

### Task types

The `text-embedding-004` model accepts a `task_type` parameter that tailors the embeddings to a specific task. For topic-based similarity or anomaly detection, `RETRIEVAL_DOCUMENT` or `CLUSTERING` may also be suitable. This notebook uses `CLASSIFICATION`, on the assumption that it captures general semantic meaning well.

In [21]:
from google.api_core import retry  # Import retry mechanism for handling API errors
import tqdm  # Import tqdm for progress bars
from tqdm.rich import tqdm as tqdmr  # Rich progress bars for better visualization
import warnings  # Suppress warnings where necessary

# Add tqdm to Pandas for progress tracking in DataFrame operations
tqdmr.pandas()

# Suppress experimental warnings from tqdm library
warnings.filterwarnings("ignore", category=tqdm.TqdmExperimentalWarning)

# Retry predicate: retry only on rate-limit (429) or service-unavailable (503) API errors.
# Note: this replaces the broader `is_retriable` defined in the imports cell.
is_retriable = lambda e: (isinstance(e, genai.errors.APIError) and e.code in {429, 503})

@retry.Retry(predicate=is_retriable, timeout=300)  # Retry on specific API errors with a timeout of 300 seconds
def embed_fn(text: str) -> list[float]:
    """
    Generates embeddings for a given text using an embedding model.

    Args:
        text (str): Input text to generate embeddings for.

    Returns:
        list[float]: Embedding vector as a list of floats.
    """
    response = client.models.embed_content(
        model="models/text-embedding-004",  # Specify the embedding model to use
        contents=text,  # Input text content from DF text column
        config=types.EmbedContentConfig(
            task_type="classification",  # Specify task type (e.g., classification)
        ),
    )

    return response.embeddings[0].values  # Return the embedding vector

def create_embeddings(df):
    """
    Adds an 'Embeddings' column to the DataFrame by generating embeddings from text.

    Args:
        df (pd.DataFrame): Input DataFrame with a 'Text' column.

    Returns:
        pd.DataFrame: DataFrame with an additional 'Embeddings' column.
    """
    df["Embeddings"] = df["Text"].progress_apply(embed_fn)  # Apply embedding generation with progress tracking
    return df


In [22]:
df_train = create_embeddings(df_train)
df_test = create_embeddings(df_test)

Output()

Output()

## Anomaly Detection Setup

Now we set up the anomaly detection task:
1.  **Create Synthetic Anomalies**: Define text examples representing categories completely different from the 'normal' (`sci.*`) newsgroup data, such as financial transactions and spam emails.
2.  **Embed Anomalies**: Generate embeddings for these synthetic samples using the same `create_embeddings` function applied to the train and test sets.
3.  **Combine**: Add the synthetic anomalous samples to the 'normal' test set (`df_test`).
4.  **Calculate Distances**: Compute the distance (e.g., cosine distance) of each point in the combined test set from the centroid (average embedding) of the 'normal' training set (`df_train`).
5.  **Identify Outliers**: Flag points with distances exceeding a threshold as anomalies. These should ideally be our synthetic examples.
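
Steps 4 and 5 can be illustrated on a handful of 2-D toy vectors. This is a library-free sketch under stated assumptions: the helper names `centroid`, `cosine_distance`, and `percentile` are hypothetical, the vectors are invented, and the 75th percentile is chosen only because the toy set has four points. The notebook itself uses NumPy, scikit-learn's `cosine_distances`, and a 90th-percentile threshold.

```python
import math


def centroid(vectors):
    """Component-wise mean of a list of equal-length vectors."""
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]


def cosine_distance(a, b):
    """1 - cosine similarity; 0 for parallel vectors, 1 for orthogonal ones."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def percentile(values, q):
    """Percentile with linear interpolation between ranks (q in [0, 100])."""
    ordered = sorted(values)
    pos = (len(ordered) - 1) * q / 100.0
    lo, hi = math.floor(pos), math.ceil(pos)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)


# Toy "normal" training embeddings, all pointing roughly along the x-axis.
train = [[1.0, 0.1], [0.9, 0.2], [1.0, 0.0]]
# Toy test set: three normal-looking points and one pointing along the y-axis.
test = [[1.0, 0.1], [0.95, 0.15], [0.9, 0.05], [0.0, 1.0]]

train_centroid = centroid(train)
distances = [cosine_distance(v, train_centroid) for v in test]

# Flag points whose distance exceeds the chosen percentile of all test distances.
threshold = percentile(distances, 75)
flags = [d > threshold for d in distances]
print(flags)  # [False, False, False, True]
```

Note that a percentile threshold computed from the test distances always flags a fixed share of the points, so it says which points are farthest from the centroid, not whether they are anomalous in an absolute sense.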

In [23]:
# Ensure df_train and df_test are not empty before proceeding
if df_train.empty or df_test.empty:
     print("ERROR: 'Normal' train or test DataFrame is empty. Cannot proceed with anomaly detection.")
else:
     # --- 1. Create Synthetic Anomalous Samples ---
     print("Creating synthetic anomaly samples...")

     # Define synthetic texts for different categories
     # 5 Financial examples
     financial_texts = [
          "URGENT: Your account statement for March is ready. Payment due April 15th. Click here to view details.",
          "Stock Alert: ACME Corp (ACME) up 5% pre-market trading following positive earnings report.",
          "Transaction Confirmation: $150.75 paid to 'OnlineRetailer'. Your new balance is $1,234.56.",
          "Loan Application Update: Your mortgage pre-approval has been processed. A loan officer will contact you shortly.",
          "Investment Opportunity: Learn about our new high-yield savings account with competitive APY rates."
     ]
     # 5 Spam examples
     spam_texts = [
          "Congratulations! You've won a FREE iPhone 15! Click HERE to claim your prize NOW!",
          "VIAGRA special offer!! Cheap pills online, discrete shipping guaranteed. Limited time only!",
          "Urgent account security warning! Verify your login details immediately by clicking this link: http://totally-not-a-scam-site.com",
          "Meet hot singles in your area tonight! Join free now, easy registration, instant matches!",
          "Make $1000s working from home! Easy online job opportunity, no experience needed. Apply today!"
     ]

     synthetic_texts = financial_texts + spam_texts
     synthetic_labels = (['Financial'] * 5) + (['Spam'] * 5)
     num_anomalies_to_sample = len(synthetic_texts) # Should be 10

     # Create DataFrame for anomalies
     df_anomalies = pd.DataFrame({
          'Text': synthetic_texts,
          'Class Name': synthetic_labels,
          # Add dummy Label if needed by downstream code, though not strictly necessary for detection
          'Label': [-1] * num_anomalies_to_sample # Assign a dummy label like -1
     })

     print(f"Created {num_anomalies_to_sample} synthetic anomaly samples.")
     print(f"Synthetic classes: {df_anomalies['Class Name'].unique().tolist()}")


     # --- 2. Generate Embeddings for Anomalies ---
     print("\nGenerating embeddings for synthetic anomalies...")
     df_anomalies = create_embeddings(df_anomalies) # Reads the 'Text' column

     # Check if embeddings were generated successfully for anomalies
     if 'Embeddings' not in df_anomalies.columns or df_anomalies.empty:
          print("Error: Failed to generate embeddings for synthetic anomaly samples. Cannot proceed.")
          # Handle error
     else:
          print(f"Successfully generated embeddings for {len(df_anomalies)} synthetic anomaly samples.")

          # --- 3. Combine Anomalies with Test Set ---
          # Add a flag to distinguish anomalies
          df_anomalies['Is_Anomaly'] = True
          # Mark every row of the original ('normal') test set as not anomalous
          df_test['Is_Anomaly'] = False


          # Combine the original test set with the new anomalies
          df_test_combined = pd.concat([df_test, df_anomalies], ignore_index=True)
          print(f"\nCombined test set size: {len(df_test_combined)} rows")

          # --- 4. Detect Anomalies using Distance from Training Centroid ---
          print("Calculating training centroid and distances...")

          # Ensure training embeddings are ready
          if 'Embeddings' not in df_train.columns or df_train.empty:
               print("Error: Training embeddings not available. Cannot calculate centroid.")
          else:
                x_train_embeddings = np.stack(df_train['Embeddings'].values)
                # Ensure combined test embeddings are ready (check after concat might be needed if anomalies failed)
                if 'Embeddings' not in df_test_combined.columns or df_test_combined['Embeddings'].isnull().any():
                     print("Warning: Missing embeddings in combined test set after concat. Dropping rows with missing embeddings.")
                     df_test_combined.dropna(subset=['Embeddings'], inplace=True)
                     print(f"Proceeding with {len(df_test_combined)} rows in combined test set.")


                if not df_test_combined.empty: # Check if still have data after potential drop
                     x_test_combined_embeddings = np.stack(df_test_combined['Embeddings'].values)

                     # Calculate the centroid (mean embedding) of the 'normal' training data
                     train_centroid = np.mean(x_train_embeddings, axis=0)

                     # Calculate cosine distance from each point in the combined test set to the training centroid
                     distances = cosine_distances(x_test_combined_embeddings, train_centroid.reshape(1, -1))
                     df_test_combined['Distance_to_Centroid'] = distances.flatten()

                     # --- 5. Identify & Save Anomalies ---
                     # Determine a threshold (e.g., 90th percentile of distances)
                     valid_distances = df_test_combined['Distance_to_Centroid'].dropna()
                     if not valid_distances.empty:
                          # Calculate percentile based on expected number of anomalies
                          # num_total = len(df_test_combined)
                          # anomaly_percentile = (1 - (num_anomalies_to_sample / num_total)) * 100 if num_total > 0 else 90
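                          # Using a fixed 90th percentile for simplicity. Because the threshold is
                          # computed from the combined test set itself, roughly the top 10% of rows
                          # are flagged regardless of their content.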
                          anomaly_percentile = 90
                          distance_threshold = np.percentile(valid_distances, anomaly_percentile)
                          print(f"Using distance threshold ({anomaly_percentile:.0f}th percentile): {distance_threshold:.4f}")

                          # Flag potential anomalies based on the threshold
                          df_test_combined['Detected_Anomaly'] = df_test_combined['Distance_to_Centroid'] > distance_threshold

                          # Separate the detected anomalies
                          detected_anomalies_df = df_test_combined[df_test_combined['Detected_Anomaly']].copy() # Boolean mask selects flagged rows

                          print("\n--- Detected Anomalies ---")
                          if not detected_anomalies_df.empty:
                               # Show the first five detected rows (head() keeps DataFrame index order, not distance order)
                               print(detected_anomalies_df[['Text', 'Class Name', 'Is_Anomaly', 'Distance_to_Centroid']].head())

                               # Compare injected anomalies against detected ones
                               # Count injected anomalies still present in df_test_combined (rows may have been dropped above)
                               num_injected_present = df_test_combined['Is_Anomaly'].sum()
                               correctly_detected = detected_anomalies_df['Is_Anomaly'].sum()
                               false_positives = len(detected_anomalies_df) - correctly_detected
                               print(f"\nTotal detected: {len(detected_anomalies_df)}")
                               print(f"Correctly identified injected anomalies: {correctly_detected}/{num_injected_present} (Expected {num_anomalies_to_sample} injected initially)")
                               print(f"Incorrectly identified (false positives): {false_positives}")

                               # Save the detected anomalies to a CSV
                               try:
                                    anomaly_filename = 'detected_synthetic_anomalies.csv'
                                    # Select columns to save ('Embeddings' is written as text and must be parsed when reloaded)
                                    columns_to_save = ['Text', 'Class Name', 'Is_Anomaly', 'Distance_to_Centroid', 'Detected_Anomaly', 'Embeddings']
                                    detected_anomalies_df[columns_to_save].to_csv(anomaly_filename, index=False)
                                    print(f"\nDetected anomalies saved to {anomaly_filename}")
                               except Exception as e:
                                    print(f"\nError saving anomalies to CSV: {e}")
                          else:
                               print("No anomalies detected above the threshold.")
                     else:
                          print("Error: No valid distances found to calculate threshold.")

                     # Display the head of the combined test df showing new columns
                     print("\n--- Head of Combined Test Set with Anomaly Info ---")
                     print(df_test_combined[['Text', 'Class Name', 'Is_Anomaly', 'Distance_to_Centroid', 'Detected_Anomaly']].head())
                else:
                    print("Combined test set is empty after handling missing embeddings.")
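
The centroid-distance scoring and percentile threshold used in the cell above can be tried in isolation. The following is a minimal sketch on made-up 3-dimensional vectors, not the notebook's code: the function names `centroid_anomaly_flags` and `precision_recall`, the toy data, and the NumPy-only cosine distance (in place of `cosine_distances`) are all assumptions for illustration.

```python
import numpy as np


def centroid_anomaly_flags(train_emb, test_emb, percentile=90.0):
    """Flag test rows whose cosine distance to the training centroid
    exceeds a percentile of the test distances (hypothetical helper)."""
    train_emb = np.asarray(train_emb, dtype=float)
    test_emb = np.asarray(test_emb, dtype=float)

    # Centroid of the 'normal' data: the mean embedding.
    centroid = train_emb.mean(axis=0)

    # Cosine distance = 1 - cosine similarity, computed row by row.
    sims = (test_emb @ centroid) / (
        np.linalg.norm(test_emb, axis=1) * np.linalg.norm(centroid)
    )
    distances = 1.0 - sims

    # Threshold is taken from the test distances themselves.
    threshold = np.percentile(distances, percentile)
    return distances, threshold, distances > threshold


def precision_recall(flags, is_anomaly):
    """Precision and recall of boolean flags against known labels."""
    flags = np.asarray(flags, dtype=bool)
    is_anomaly = np.asarray(is_anomaly, dtype=bool)
    true_pos = int(np.sum(flags & is_anomaly))
    precision = true_pos / flags.sum() if flags.sum() else float("nan")
    recall = true_pos / is_anomaly.sum() if is_anomaly.sum() else float("nan")
    return precision, recall


# Toy data (assumed): 'normal' vectors point along x, 'odd' vectors along y.
rng = np.random.default_rng(0)
normal_dir = np.array([1.0, 0.0, 0.0])
odd_dir = np.array([0.0, 1.0, 0.0])
train = normal_dir + 0.05 * rng.normal(size=(50, 3))
test = np.vstack([
    normal_dir + 0.05 * rng.normal(size=(18, 3)),  # 18 normal rows
    odd_dir + 0.05 * rng.normal(size=(2, 3)),      # 2 injected anomalies
])
labels = np.array([False] * 18 + [True] * 2)

distances, threshold, flags = centroid_anomaly_flags(train, test, percentile=90)
precision, recall = precision_recall(flags, labels)
print(f"threshold={threshold:.4f} flagged={int(flags.sum())} "
      f"precision={precision:.2f} recall={recall:.2f}")
```

Because the threshold is a percentile of the test distances, the number of flagged rows is fixed by the percentile rather than by how many anomalies exist. In the notebook's run, a 90th-percentile cut over 110 rows flags 11 rows, so with 10 injected anomalies at least one normal row is flagged by construction. The output printed below comes from the notebook cell above, not from this sketch.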


Creating synthetic anomaly samples...
Created 10 synthetic anomaly samples.
Synthetic classes: ['Financial', 'Spam']

Generating embeddings for synthetic anomalies...


Successfully generated embeddings for 10 synthetic anomaly samples.

Combined test set size: 110 rows
Calculating training centroid and distances...
Using distance threshold (90th percentile): 0.1956

--- Detected Anomalies ---
                                                  Text Class Name  Is_Anomaly  \
51   cure for dry skin?\n\nHi all,\n\nMy skin is ve...    sci.med       False   
100  URGENT: Your account statement for March is re...  Financial        True   
101  Stock Alert: ACME Corp (ACME) up 5% pre-market...  Financial        True   
102  Transaction Confirmation: $150.75 paid to 'Onl...  Financial        True   
103  Loan Application Update: Your mortgage pre-app...  Financial        True   

     Distance_to_Centroid  
51               0.225113  
100              0.323991  
101              0.235453  
102              0.349619  
103              0.269010  

Total detected: 11
Correctly identified injected anomalies: 10/10 (Expected 10 injected initially)
Incorrectly ident