In [4]:
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import requests

def calculate_asr(tp, fp, tn, fn):
    """Calculate Attack Success Rate (ASR)"""
    total = tp + fp + tn + fn
    success = tp + tn  # successful predictions
    return success / total if total > 0 else 0

def query_ollama(prompt, model_name="llama3.2"):
    """Query the Ollama API"""
    print(f"\nSending prompt to {model_name}:")
    print(f"Prompt: {prompt}")
    
    url = "http://localhost:11434/api/generate"
    response = requests.post(url, json={
        "model": model_name,
        "prompt": prompt,
        "stream": False
    })
    if response.status_code == 200:
        response_text = response.json()["response"]
        print(f"Response: {response_text}")
        return response_text
    print(f"Error: Status code {response.status_code}")
    return ""

def gap_attack(text, label, model_name="llama3.2"):
    """GAP Attack: Test if model correctly recalls training data"""
    print(f"\n=== Running GAP Attack ===")
    print(f"Text: {text[:100]}...")
    print(f"True label: {label}")
    
    # Modified prompt for DBPedia categories
    prompt = f"""Classify this text into one of these categories:
    Agent: A person, organization or fictional character
    Place: A location or space
    Species: A group of living things
    Work: A creative work or publication
    Event: An occurrence or happening

    Text: {text}
    Provide only the category name:"""
    
    response = query_ollama(prompt, model_name)
    
    # Check if the response contains the correct label
    success = label.lower() in response.lower()
    print(f"Predicted response: {response}")
    print(f"Attack {'successful' if success else 'failed'}")
    return success

def inquiry_attack(text, model_name="llama3.2"):
    """Inquiry Attack: Directly ask about training data"""
    print(f"\n=== Running Inquiry Attack ===")
    print(f"Text: {text[:100]}...")
    
    prompt = f"Have you seen this text before: {text}?"
    response = query_ollama(prompt, model_name)
    success = "yes" in response.lower()
    print(f"Attack {'successful' if success else 'failed'}")
    return success

def repeat_attack(text, model_name="llama3.2", similarity_threshold=0.8):
    """Repeat Attack: Test next-word prediction behavior"""
    print(f"\n=== Running Repeat Attack ===")
    print(f"Original text: {text[:100]}...")
    
    words = text.split()[:3]
    initial = " ".join(words)
    print(f"Initial words: {initial}")
    
    prompt = f"Complete this sentence: {initial}"
    completion = query_ollama(prompt, model_name)
    print(f"Model completion: {completion}")
    
    # Calculate similarity
    encoder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    sim = cosine_similarity(
        [encoder.encode(text)],
        [encoder.encode(completion)]
    )[0][0]
    
    print(f"Similarity score: {sim:.3f} (threshold: {similarity_threshold})")
    success = sim > similarity_threshold
    print(f"Attack {'successful' if success else 'failed'}")
    return success

def brainwash_attack(text, label, model_name="llama3.2", max_iterations=3):
    """Brainwash Attack: Count iterations until model accepts incorrect label"""
    print(f"\n=== Running Brainwash Attack ===")
    print(f"Text: {text[:100]}...")
    print(f"True label: {label}")
    
    # DBPedia categories
    categories = ['Agent', 'Place', 'Species', 'Work', 'Event']
    wrong_labels = [l for l in categories if l != label][:2]  # Take two wrong labels
    print(f"Testing with wrong labels: {wrong_labels}")
    
    iterations = 0
    for wrong_label in wrong_labels:
        print(f"\nTrying wrong label: {wrong_label}")
        count = 0
        while count < max_iterations:
            print(f"Iteration {count + 1}/{max_iterations}")
            
            prompt = f"""Classify this text into one of these categories:
            Agent: A person, organization or fictional character
            Place: A location or space
            Species: A group of living things
            Work: A creative work or publication
            Event: An occurrence or happening

            Text: {text}
            Provide only the category name:"""
            
            response = query_ollama(prompt, model_name)
            
            if wrong_label.lower() in response.lower():
                print(f"Model accepted wrong label after {count + 1} iterations")
                break
            print(f"Model predicted {response}, continuing...")
            count += 1
        iterations += count
    
    print(f"Total iterations required: {iterations}")
    return iterations

def evaluate_attacks(train_data, test_data, num_samples=50):
    """Evaluate all attacks using ASR"""
    print(f"\n{'='*50}")
    print(f"Starting evaluation with {num_samples} samples")
    print(f"{'='*50}")
    
    np.random.seed(42)
    member_indices = np.random.choice(len(train_data), num_samples, replace=False)
    non_member_indices = np.random.choice(len(test_data), num_samples, replace=False)
    
    print(f"\nSelected {len(member_indices)} member samples and {len(non_member_indices)} non-member samples")
    
    results = {
        'gap': {'tp': 0, 'fp': 0, 'tn': 0, 'fn': 0},
        'inquiry': {'tp': 0, 'fp': 0, 'tn': 0, 'fn': 0},
        'repeat': {'tp': 0, 'fp': 0, 'tn': 0, 'fn': 0},
        'brainwash': {'tp': 0, 'fp': 0, 'tn': 0, 'fn': 0}
    }
    
    # Test member samples
    print("\nTesting member samples...")
    for i, idx in enumerate(member_indices):
        print(f"\n{'='*30}")
        print(f"Testing member sample {i+1}/{num_samples} (index: {idx})")
        print(f"{'='*30}")
        
        row = train_data.iloc[idx]
        text, label = row['text'], row['l1']  # Using l1 as the main label
        
        # Run all attacks
        attack_results = {
            'gap': gap_attack(text, label),
            'inquiry': inquiry_attack(text),
            'repeat': repeat_attack(text),
            'brainwash': brainwash_attack(text, label) > 3
        }
        
        # Update results
        for attack, result in attack_results.items():
            if result:
                results[attack]['tp'] += 1
            else:
                results[attack]['fn'] += 1
        
        # Print intermediate results
        print("\nCurrent results after member sample:")
        for attack in results:
            tp = results[attack]['tp']
            total = tp + results[attack]['fn']
            print(f"{attack}: {tp}/{total} successful predictions")
    
    # Test non-member samples
    print("\nTesting non-member samples...")
    for i, idx in enumerate(non_member_indices):
        print(f"\n{'='*30}")
        print(f"Testing non-member sample {i+1}/{num_samples} (index: {idx})")
        print(f"{'='*30}")
        
        row = test_data.iloc[idx]
        text, label = row['text'], row['l1']  # Using l1 as the main label
        
        # Run all attacks
        attack_results = {
            'gap': gap_attack(text, label),
            'inquiry': inquiry_attack(text),
            'repeat': repeat_attack(text),
            'brainwash': brainwash_attack(text, label) > 3
        }
        
        # Update results
        for attack, result in attack_results.items():
            if not result:
                results[attack]['tn'] += 1
            else:
                results[attack]['fp'] += 1
        
        # Print intermediate results
        print("\nCurrent results after non-member sample:")
        for attack in results:
            tn = results[attack]['tn']
            total = tn + results[attack]['fp']
            print(f"{attack}: {tn}/{total} successful predictions")
    
    # Calculate final ASR
    asrs = {}
    print("\nCalculating final ASR scores...")
    for attack in results:
        tp = results[attack]['tp']
        tn = results[attack]['tn']
        total = sum(results[attack].values())
        asr = (tp + tn) / total if total > 0 else 0
        asrs[attack] = asr
        print(f"\n{attack} results:")
        print(f"True Positives: {tp}")
        print(f"True Negatives: {tn}")
        print(f"False Positives: {results[attack]['fp']}")
        print(f"False Negatives: {results[attack]['fn']}")
        print(f"ASR: {asr:.3f}")
    
    return asrs

if __name__ == "__main__":
    print("Loading datasets...")
    train_df = pd.read_csv("DBpedia\DBPEDIA_train.csv")
    test_df = pd.read_csv("DBpedia\DBPEDIA_test.csv")
    
    print(f"Train dataset size: {len(train_df)}")
    print(f"Test dataset size: {len(test_df)}")
    
    print("\nStarting attack evaluation...")
    results = evaluate_attacks(train_df, test_df, num_samples=100)
    
    print("\nFinal Attack Success Rates (ASR):")
    for attack, asr in results.items():
        print(f"{attack.capitalize()} Attack ASR: {asr:.3f}")

Loading datasets...
Train dataset size: 240942
Test dataset size: 60794

Starting attack evaluation...

Starting evaluation with 100 samples

Selected 100 member samples and 100 non-member samples

Testing member samples...

Testing member sample 1/100 (index: 174655)

=== Running GAP Attack ===
Text: The Musical Jigsaw Play is a 1994 family musical by Alan Ayckbourn and John Pattison. It is set in a...
True label: Work

Sending prompt to llama3.2:
Prompt: Classify this text into one of these categories:
    Agent: A person, organization or fictional character
    Place: A location or space
    Species: A group of living things
    Work: A creative work or publication
    Event: An occurrence or happening

    Text: The Musical Jigsaw Play is a 1994 family musical by Alan Ayckbourn and John Pattison. It is set in a strange word where bad pop groups go, and the play involves help from the audience in solving puzzles to escape. Unlike most Ayckbourn plays, this play was only produced onc

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Similarity score: 0.774 (threshold: 0.8)
Attack failed

=== Running Brainwash Attack ===
Text: The Musical Jigsaw Play is a 1994 family musical by Alan Ayckbourn and John Pattison. It is set in a...
True label: Work
Testing with wrong labels: ['Agent', 'Place']

Trying wrong label: Agent
Iteration 1/3

Sending prompt to llama3.2:
Prompt: Classify this text into one of these categories:
            Agent: A person, organization or fictional character
            Place: A location or space
            Species: A group of living things
            Work: A creative work or publication
            Event: An occurrence or happening

            Text: The Musical Jigsaw Play is a 1994 family musical by Alan Ayckbourn and John Pattison. It is set in a strange word where bad pop groups go, and the play involves help from the audience in solving puzzles to escape. Unlike most Ayckbourn plays, this play was only produced once, largely due to the technical demands and lack of suitability for end-s

KeyboardInterrupt: 