In [None]:
# ------------------------------------------------------
# DEFENSE MED ISOLATION FOREST
# Detekterar och tar bort misstänkta poisoned exempel
# ------------------------------------------------------

from datasets import load_dataset, Dataset
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, DistilBertModel
from transformers import TrainingArguments, Trainer
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from sklearn.ensemble import IsolationForest
import torch
import csv, os, random


# -------------------------------------------------------------------------
# 1. Save results funktion
# -------------------------------------------------------------------------

def save_results(attack_type, attack_rate, accuracy, f1, train_size, confusion_matrix,
                 defense_used=None, removed_count=None, filename="results/logs/defense_flip.csv"):
    """Append one experiment result row to a CSV log file.

    The header row is written first whenever the file does not exist yet or
    exists but is empty.

    Args:
        attack_type: Label identifying the attack/defense scenario.
        attack_rate: Poisoning rate used in the experiment.
        accuracy: Accuracy value to log.
        f1: F1 value to log.
        train_size: Number of training examples used.
        confusion_matrix: Confusion matrix. Objects exposing ``tolist()``
            (e.g. a NumPy array) are converted to nested lists; any other
            value is written unchanged.
        defense_used: Optional name of the defense that was applied.
        removed_count: Optional number of examples removed by the defense.
        filename: Path of the CSV file to append to. Missing parent
            directories are created.
    """
    # os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when the path actually has one (e.g. not for "log.csv").
    directory = os.path.dirname(filename)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # An existing but empty file still needs the header row.
    needs_header = not os.path.isfile(filename) or os.path.getsize(filename) == 0

    # Accept both array-like objects and plain nested lists.
    to_list = getattr(confusion_matrix, "tolist", None)
    matrix_cell = to_list() if callable(to_list) else confusion_matrix

    with open(filename, mode="a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)

        if needs_header:
            writer.writerow(["attack_type", "attack_rate", "accuracy", "f1",
                             "train_size", "confusion_matrix", "defense_used", "removed_count"])

        writer.writerow([
            attack_type,
            attack_rate,
            accuracy,
            f1,
            train_size,
            matrix_cell,
            defense_used,
            removed_count,
        ])

    print(f"‚úî Resultat sparat i {filename}")


# -------------------------------------------------------------------------
# 2. Flip-label funktion (samma som tidigare)
# -------------------------------------------------------------------------

def flip_labels(dataset, percentage=0.1):
    """Return a copy of ``dataset`` with a random share of labels flipped.

    A flipped label becomes ``1 - label`` (1 -> 0, 0 -> 1).

    Args:
        dataset: Dataset-like object supporting ``len()``, ``select()`` and
            ``map(fn, with_indices=True)`` whose examples carry a
            ``"label"`` field.
        percentage: Fraction of the examples whose label is flipped;
            ``int(len(dataset) * percentage)`` examples are chosen.

    Returns:
        A tuple ``(poisoned, flip_idx)``: the mapped dataset and the list of
        indices that were flipped, in the order they were sampled.
    """
    n = len(dataset)
    k = int(n * percentage)

    flip_idx = random.sample(range(n), k)
    # The callback runs once per example, so test membership against a set
    # instead of scanning the sampled list every time.
    flip_lookup = set(flip_idx)

    def flip(example, idx):
        if idx in flip_lookup:
            example["label"] = 1 - example["label"]
        return example

    poisoned = dataset.select(range(n)).map(flip, with_indices=True)
    return poisoned, flip_idx


# -------------------------------------------------------------------------
# 3. Extrahera embeddings från DistilBERT
# -------------------------------------------------------------------------

def extract_embeddings(dataset, tokenizer, model_name="distilbert-base-uncased", device="cpu",
                       max_length=256, batch_size=1):
    """Extract one DistilBERT embedding per example in ``dataset``.

    Each text is tokenized (truncated and padded to ``max_length``), run
    through ``DistilBertModel`` without gradient tracking, and represented by
    the last hidden state at the first token position (the CLS position).

    Args:
        dataset: Iterable with ``len()`` whose examples expose a ``"text"``
            field.
        tokenizer: Tokenizer called on the texts; must match ``model_name``.
        model_name: Checkpoint loaded via ``DistilBertModel.from_pretrained``.
        device: Torch device the model and inputs are moved to.
        max_length: Token length every input is truncated/padded to.
        batch_size: Number of texts encoded per forward pass. The default of
            1 processes one example at a time.

    Returns:
        A NumPy array with one row per example. For an empty dataset an
        array of shape ``(0, hidden_size)`` is returned.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    model = DistilBertModel.from_pretrained(model_name).to(device)
    model.eval()

    total = len(dataset)
    embeddings = []

    print(f"Extraherar embeddings f√∂r {total} exempel...")

    def encode(texts):
        """Return the first-token embeddings for a list of texts."""
        inputs = tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=max_length
        ).to(device)
        outputs = model(**inputs)
        # First token position of the last hidden layer, one row per text.
        return outputs.last_hidden_state[:, 0, :].cpu().numpy()

    with torch.no_grad():
        pending = []
        for i, example in enumerate(dataset):
            if i % 100 == 0:
                print(f"  {i}/{total}", end="\r")

            pending.append(example["text"])
            if len(pending) == batch_size:
                embeddings.extend(encode(pending))
                pending = []

        # Encode the final, possibly smaller, batch.
        if pending:
            embeddings.extend(encode(pending))

    print(f"\n‚úî Embeddings extraherade: {len(embeddings)}")

    if not embeddings:
        # Keep a 2-D shape so downstream code sees (n_samples, n_features).
        return np.empty((0, model.config.hidden_size), dtype=np.float32)
    return np.array(embeddings)


# -------------------------------------------------------------------------
# 4. IsolationForest för att detektera outliers
# -------------------------------------------------------------------------

def detect_outliers(embeddings, contamination=0.1):
    """Flag outlier rows of ``embeddings`` with an IsolationForest.

    Args:
        embeddings: Feature matrix handed to ``IsolationForest.fit_predict``.
        contamination: Expected share of outliers, forwarded to the
            IsolationForest (the script passes the poison rate here).

    Returns:
        NumPy array with the row indices that were predicted as outliers.
    """
    print(f"\nK√∂r IsolationForest med contamination={contamination}")

    forest = IsolationForest(
        n_estimators=100,
        contamination=contamination,
        random_state=42,
    )

    # fit_predict labels every row: -1 for an outlier, 1 for an inlier.
    labels = forest.fit_predict(embeddings)
    flagged = np.flatnonzero(labels == -1)

    print(f"‚úî Detekterade {len(flagged)} outliers")
    return flagged


# -------------------------------------------------------------------------
# 5. Ta bort outliers från dataset
# -------------------------------------------------------------------------

def remove_outliers(dataset, outlier_indices):
    """Return ``dataset`` without the examples at ``outlier_indices``.

    Args:
        dataset: Dataset-like object supporting ``len()`` and ``select()``.
        outlier_indices: Iterable of positions to drop.

    Returns:
        The result of ``dataset.select`` on the remaining positions, in
        ascending order.
    """
    size_before = len(dataset)
    dropped = set(outlier_indices)
    keep_indices = [i for i in range(size_before) if i not in dropped]

    cleaned_dataset = dataset.select(keep_indices)
    print(f"‚úî Dataset rensat: {size_before} ‚Üí {len(cleaned_dataset)} exempel")

    return cleaned_dataset


# -------------------------------------------------------------------------
# 6. HUVUDEXPERIMENT: Label Flipping Defense
# -------------------------------------------------------------------------

ATTACK_RATE = 0.10  # Change this to run the experiment at a different poison rate
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using device: {DEVICE}")
print(f"Attack rate: {ATTACK_RATE * 100}%\n")

# flip_labels() samples with Python's global RNG; seed it so the flipped
# indices are reproducible like the rest of the pipeline (seed 42).
random.seed(42)

# Load the dataset
dataset = load_dataset("imdb")

train = dataset["train"].shuffle(seed=42).select(range(500))

# Validation and test are cut from the same shuffled split, so they must use
# disjoint index ranges. Otherwise the best-model selection on the validation
# set would be done on the very examples used for the final test score.
shuffled_test = dataset["test"].shuffle(seed=42)
val = shuffled_test.select(range(250))
test = shuffled_test.select(range(250, 500))

print("Dataset loaded:", len(train), len(val), len(test))

# Create the poisoned training data
poisoned_train, flipped_idx = flip_labels(train, percentage=ATTACK_RATE)
print(f"Antal flippade exempel: {len(flipped_idx)}")

# Extract embeddings
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
embeddings = extract_embeddings(poisoned_train, tokenizer, device=DEVICE)

# Detect outliers
outlier_indices = detect_outliers(embeddings, contamination=ATTACK_RATE)

# Analyse how many of the poisoned examples were caught
outlier_set = set(outlier_indices)
flipped_set = set(flipped_idx)
detected_poisoned = len(outlier_set & flipped_set)
false_positives = len(outlier_set - flipped_set)
missed_poisoned = len(flipped_set - outlier_set)

# A very small attack rate can flip zero examples; avoid dividing by zero.
detection_rate = detected_poisoned / len(flipped_idx) * 100 if flipped_idx else 0.0

print(f"\nüìä Detection Analysis:")
print(f"  Total poisoned examples: {len(flipped_idx)}")
print(f"  Detected outliers: {len(outlier_indices)}")
print(f"  True positives (poisoned detected): {detected_poisoned}")
print(f"  False positives (clean flagged): {false_positives}")
print(f"  False negatives (poisoned missed): {missed_poisoned}")
print(f"  Detection rate: {detection_rate:.1f}%")

# Clean the dataset
cleaned_train = remove_outliers(poisoned_train, outlier_indices)


# -------------------------------------------------------------------------
# 7. Tokenisering
# -------------------------------------------------------------------------

def tokenize(batch):
    """Tokenize the ``"text"`` entries of a batch with the module-level tokenizer.

    Inputs are truncated and padded to a fixed length of 256 tokens.
    """
    texts = batch["text"]
    encoded = tokenizer(texts, max_length=256, padding="max_length", truncation=True)
    return encoded

def _to_model_inputs(split):
    """Tokenize a split, rename ``label`` to ``labels`` and set torch format."""
    encoded = split.map(tokenize, batched=True).rename_column("label", "labels")
    encoded.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
    return encoded


# Apply the same preparation to all three splits.
train_tok = _to_model_inputs(cleaned_train)
val_tok = _to_model_inputs(val)
test_tok = _to_model_inputs(test)


# -------------------------------------------------------------------------
# 8. Träna modell på rensad data
# -------------------------------------------------------------------------

# Fresh sequence-classification model loaded from the base checkpoint.
checkpoint = "distilbert-base-uncased"
model = DistilBertForSequenceClassification.from_pretrained(checkpoint)

def compute_metrics(pred):
    """Compute accuracy and F1 from a ``(logits, labels)`` pair.

    The predicted class is the argmax over the last axis of the logits.
    """
    scores, references = pred
    predicted = np.argmax(scores, axis=-1)

    metrics = {}
    metrics["accuracy"] = accuracy_score(references, predicted)
    metrics["f1"] = f1_score(references, predicted)
    return metrics

# Training configuration: evaluate, save and log once per epoch, and reload
# the checkpoint with the highest validation accuracy when training ends.
training_config = dict(
    output_dir=f"defense_output_rate_{int(ATTACK_RATE * 100)}",
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=32,
    learning_rate=5e-5,
    weight_decay=0.01,
    warmup_ratio=0.1,
    seed=42,
)
args = TrainingArguments(**training_config)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_tok,
    eval_dataset=val_tok,
    compute_metrics=compute_metrics,
)

print("\nüöÄ Tr√§nar modell p√• rensad data...")
trainer.train()


# -------------------------------------------------------------------------
# 9. Utvärdera modellen
# -------------------------------------------------------------------------

print("\nüìä Evaluating on test set...")
# One predict() pass returns both the metrics and the raw predictions, so the
# test set is not run through the model twice. metric_key_prefix="eval" keeps
# the "eval_*" metric names used below.
pred_out = trainer.predict(test_tok, metric_key_prefix="eval")
test_results = pred_out.metrics
print(test_results)

test_accuracy = test_results["eval_accuracy"]
test_f1 = test_results["eval_f1"]

# Confusion matrix
logits = pred_out.predictions
y_pred = np.argmax(logits, axis=-1)
y_true = pred_out.label_ids

# labels=[0, 1] pins a 2x2 layout even if one class never shows up.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
print("\nConfusion Matrix (defended):")
print(cm)


# -------------------------------------------------------------------------
# 10. Spara resultat
# -------------------------------------------------------------------------

# Log the defended run, then print a short summary.
removed = len(outlier_indices)
removed_share = removed / len(poisoned_train) * 100

save_results(
    attack_type="label_flip_defended",
    attack_rate=ATTACK_RATE,
    accuracy=test_accuracy,
    f1=test_f1,
    train_size=len(cleaned_train),
    confusion_matrix=cm,
    defense_used="IsolationForest",
    removed_count=removed,
)

print("\n‚úî DEFENSE EXPERIMENT KLAR!")
print(f"Final accuracy: {test_accuracy:.4f}")
print(f"Removed {removed} examples ({removed_share:.1f}%)")