In [5]:
import numpy as np
import json
import requests
from sklearn.linear_model import LogisticRegression
import os

# --- Evaluator endpoint -------------------------------------------------
# >>> IMPORTANT: point this at your spawned instance (IP and port) before running <<<
evaluator_base_url = "http://<IP>:<PORT>"  # CHANGE THIS
# Example: evaluator_base_url = "http://10.10.10.1:5555"

# --- Data and reproducibility -------------------------------------------
dataset_filename = "label_flipping_dataset.npz"
# One seed is reused for the label-flip selection and for the model fit.
random_seed = 1337
np.random.seed(random_seed)  # also seeds NumPy's legacy global RNG

# --- Attack Configuration -----------------------------------------------
TARGET_CLASS_TO_POISON = 0  # class whose training labels get flipped
NEW_LABEL_FOR_POISONED = 1  # label written onto the flipped samples
POISON_FRACTION = 0.00  # share of TARGET_CLASS_TO_POISON samples to flip, in [0, 1]

# Load Data
# Reads the train/test splits from the .npz archive and warns when either of
# the two attack classes does not occur in the training labels. Any loading
# error is reported and then re-raised so the notebook stops here.
print(f"Loading data from: {dataset_filename}")
try:
    # The object returned by np.load for an .npz file is a context manager;
    # `with` closes the archive even if one of the key lookups raises.
    with np.load(dataset_filename) as data:
        X_train = data["Xtr"]
        y_train = data["ytr"]
        X_test = data["Xte"]
        y_test = data["yte"]
    print("Data loaded successfully.")
    print(f"X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")
    unique_classes_train = np.unique(y_train)
    print(f"Unique classes in training data: {unique_classes_train}")
    if (
        TARGET_CLASS_TO_POISON not in unique_classes_train
        or NEW_LABEL_FOR_POISONED not in unique_classes_train
    ):
        print("Warning: Target or new label class not found in training data.")
except FileNotFoundError:
    print(f"Error: Dataset file '{dataset_filename}' not found.")
    raise
except KeyError as e:
    print(f"Error: Could not find expected array key '{e}' in the .npz file.")
    raise
except Exception as e:
    print(f"An unexpected error occurred during data loading: {e}")
    raise

Loading data from: label_flipping_dataset.npz
Data loaded successfully.
X_train shape: (700, 2), y_train shape: (700,)
Unique classes in training data: [0 1]


In [10]:
def targeted_class_label_flip(y_train, target_class, new_class, poison_fraction, seed):
    """Relabel a random subset of one class's training labels as another class.

    Args:
        y_train: NumPy array of training labels. It is never modified; a copy
            is returned instead.
        target_class: Label value whose samples are candidates for flipping.
        new_class: Label value written onto the selected samples.
        poison_fraction: Fraction (0 to 1 inclusive) of the target-class
            samples to flip. The resulting count is truncated to an integer.
        seed: Seed handed to ``np.random.default_rng`` for the selection.

    Returns:
        A tuple ``(y_train_poisoned, flipped_indices)``: the relabelled copy of
        ``y_train`` and the positions in ``y_train`` that were flipped. When
        nothing is flipped, the copy is unchanged and the index array is an
        empty integer array.

    Raises:
        ValueError: If ``poison_fraction`` is outside [0, 1], if the two
            classes are equal, or if either class is absent from ``y_train``.
    """

    def _no_change():
        # Result used whenever no label ends up being flipped.
        return y_train.copy(), np.array([], dtype=int)

    if not (0 <= poison_fraction <= 1):
        raise ValueError("Poison fraction must be between 0 and 1.")
    if target_class == new_class:
        raise ValueError("Target class and new label must be different.")

    labels_present = np.unique(y_train)
    if not (target_class in labels_present and new_class in labels_present):
        raise ValueError("Target class or new label not found in training labels.")

    # Positions in y_train that currently carry the target label.
    candidate_positions = np.nonzero(y_train == target_class)[0]
    candidate_count = candidate_positions.size
    if candidate_count == 0:
        print(f"No samples found for target class {target_class}. No poisoning applied.")
        return _no_change()

    flip_count = int(candidate_count * poison_fraction)
    if flip_count == 0:
        print("Poison fraction is too low; no samples will be flipped.")
        return _no_change()

    # Draw distinct offsets into the candidate list, then map them back to
    # positions in the full label array.
    picker = np.random.default_rng(seed)
    chosen_offsets = picker.choice(candidate_count, size=flip_count, replace=False)
    flipped_indices = candidate_positions[chosen_offsets]

    relabelled = y_train.copy()
    relabelled[flipped_indices] = new_class
    return relabelled, flipped_indices

In [12]:
# Execute the attack: flip POISON_FRACTION of the target-class labels.
y_train_poisoned, flipped_idx = targeted_class_label_flip(
    y_train,
    target_class=TARGET_CLASS_TO_POISON,
    new_class=NEW_LABEL_FOR_POISONED,
    poison_fraction=POISON_FRACTION,
    seed=random_seed,
)

# Basic Checks: report what was flipped, or note that nothing changed.
print("\n--- Post-Attack Checks ---")
if flipped_idx.size == 0:
    print("Attack function ran, but no labels were flipped (check settings and warnings).")
    print("Proceeding with potentially unpoisoned labels.")
else:
    print(f"Attack function executed, {flipped_idx.size} label(s) flipped.")
    # Only the first ten flipped positions are shown to keep the output short.
    print(f"Indices of flipped labels in training data (first 10): {flipped_idx[:10]}")
    print(f"Original labels at flipped indices (first 10): {y_train[flipped_idx[:10]]}")
    print(f"Poisoned labels at flipped indices (first 10): {y_train_poisoned[flipped_idx[:10]]}")
    print(f"Shape of poisoned labels array: {y_train_poisoned.shape}")

Poison fraction is too low; no samples will be flipped.

--- Post-Attack Checks ---
Proceeding with potentially unpoisoned labels.


In [None]:
# %%
# Fit a liblinear logistic regression on the original features paired with
# the labels produced by the attack step.
print("\n--- Training Model on Poisoned Labels ---")
model = LogisticRegression(solver="liblinear", random_state=random_seed)

try:
    model.fit(X_train, y_train_poisoned)
except Exception as fit_error:
    # Report the failure, then re-raise so later cells do not run on a bad model.
    print(f"Error during model training: {fit_error}")
    raise
else:
    print("Logistic Regression model trained successfully.")

In [None]:
# Pull the fitted coefficients and intercept out of the model and convert them
# to plain lists so they can be JSON-encoded for submission.
print("\n--- Extracting Model Parameters ---")
try:
    weights = model.coef_
    intercept = model.intercept_
    print(f"Extracted weights shape: {weights.shape}")
    print(f"Extracted intercept shape: {intercept.shape}")
    weights_list = weights.tolist()
    intercept_list = intercept.tolist()
except Exception as extraction_error:
    print(f"An unexpected error occurred during parameter extraction: {extraction_error}")
    # Leave explicit "nothing to submit" markers for the submission cell.
    weights_list = None
    intercept_list = None
    parameters_extracted = False
else:
    parameters_extracted = True

In [None]:
# Health check: GET <evaluator_base_url>/health and print the reported status.
# Every failure mode is reported with a hint instead of raising, so the
# notebook can continue to the next cell.
health_check_url = f"{evaluator_base_url}/health"
print(f"Checking evaluator health at: {health_check_url}")
# The configured placeholder is "http://<IP>:<PORT>", so those tokens (plus the
# "<EVALUATOR_IP>" spelling) all mean the URL has not been filled in yet.
if any(
    placeholder in evaluator_base_url
    for placeholder in ("<EVALUATOR_IP>", "<IP>", "<PORT>")
):
    print("\n--- WARNING ---")
    print(
        "Please update the 'evaluator_base_url' variable with the correct IP and Port before running!"
    )
    print("-------------")
else:
    try:
        response = requests.get(health_check_url, timeout=10)
        response.raise_for_status()
        health_status = response.json()
        print("\n--- Health Check Response ---")
        print(f"Status: {health_status.get('status', 'N/A')}")
        print(f"Message: {health_status.get('message', 'No message received.')}")
        if health_status.get("status") != "healthy":
            print(
                "\nWarning: Evaluator service reported an unhealthy status. It might still be starting up or encountered an issue (like loading data)."
            )
    except requests.exceptions.ConnectionError:
        print(f"\nConnection Error: Could not connect to {health_check_url}.")
        print("Please check:")
        print("  1. The evaluator URL (IP address and port) is correct.")
        print("  2. The evaluator Docker container is running.")
        print(
            "  3. There are no network issues (firewalls, etc.) blocking the connection."
        )
    except requests.exceptions.Timeout:
        print(f"\nTimeout Error: The request to {health_check_url} timed out.")
        print(
            "The server might be taking too long to respond or there could be network issues."
        )
    except json.JSONDecodeError:
        # Must precede the RequestException handler: in recent requests
        # releases the error raised by response.json() is also a
        # RequestException and would otherwise never reach this branch.
        # `response` is always bound here because only response.json() can
        # raise this error.
        print("\nError: Could not decode JSON response from health check.")
        print("The server might have sent an invalid response.")
        print(
            f"Raw response status: {response.status_code}, Raw response text: {response.text}"
        )
    except requests.exceptions.RequestException as e:
        print(f"\nError during health check request: {e}")
        print("Check the URL format and ensure the server is running.")
    except Exception as e:
        print(f"\nAn unexpected error occurred during health check: {e}")

In [None]:
# Submission: POST the extracted coefficients and intercept to
# <evaluator_base_url>/evaluate_targeted and print the evaluator's verdict.
def _format_accuracy(value):
    """Return an accuracy as a 4-decimal string, or "N/A" if it is missing or not a number."""
    if isinstance(value, (int, float)):
        return f"{value:.4f}"
    return "N/A"


evaluator_url = f"{evaluator_base_url}/evaluate_targeted"
print(f"\nAttempting submission to: {evaluator_url}")

if not parameters_extracted:
    print("Error: Cannot submit - parameters not extracted.")
elif any(
    placeholder in evaluator_base_url
    for placeholder in ("<EVALUATOR_IP>", "<IP>", "<PORT>")
):
    # The configured placeholder is "http://<IP>:<PORT>"; any of these tokens
    # means the URL has not been filled in yet.
    print("\n--- WARNING: Update evaluator_base_url ---")
else:
    payload = {"coef": weights_list, "intercept": intercept_list}
    print(f"Payload preview: {json.dumps(payload, indent=2)}")

    try:
        response = requests.post(evaluator_url, json=payload, timeout=30)
        response.raise_for_status()
        result = response.json()

        # Both outcomes report the same two metrics, formatted the same way so
        # an unexpected (non-numeric) value can never hide the server message.
        oa_str = _format_accuracy(result.get("overall_accuracy"))
        c0a_str = _format_accuracy(result.get("class0_accuracy"))

        print("\n--- Evaluator Response ---")
        if result.get("success"):
            print(f"{'=' * 10} Attack Successful! {'=' * 10}")
            print(f"Overall Accuracy evaluated: {oa_str}")
            print(f"Accuracy on Class 0 samples: {c0a_str}")
            print(f"Message: {result.get('message', 'N/A')}")
            print(f"\nFLAG: {result.get('flag')}")
            print(f"{'=' * 38}")
        else:
            print("Evaluation Failed.")
            print(f"Overall Accuracy evaluated: {oa_str}")
            print(f"Accuracy on Class 0 samples: {c0a_str}")
            print(f"Message: {result.get('message', 'No message provided.')}")
            print(
                "\nHints: Did the attack significantly reduce accuracy specifically for Class 0 samples?"
            )
            print("Did the overall accuracy remain above the required threshold?")
            print("Consider adjusting the POISON_FRACTION.")

    except requests.exceptions.ConnectionError:
        print(f"\nConnection Error: Could not connect to {evaluator_url}.")
    except requests.exceptions.Timeout:
        print(f"\nTimeout Error: Request to {evaluator_url} timed out.")
    except requests.exceptions.RequestException as e:
        print(f"\nError during submission request: {e}")
        if e.response is not None:
            print(f"Server Response Status Code: {e.response.status_code}")
            try:
                print(f"Server Response Body: {e.response.json()}")
            except ValueError:
                # ValueError is the common base of the JSON decode errors that
                # response.json() can raise, whichever JSON backend is in use.
                print(f"Server Response Body (non-JSON): {e.response.text}")
    except Exception as e:
        print(f"\nAn unexpected error occurred during submission: {e}")