<a href="https://colab.research.google.com/github/tanghongyi0406/CCNewsPDD/blob/main/Code/NAPDD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pip Install and Imports


In [None]:

# Pin specific package versions: first remove any preinstalled copies, then
# install PyTorch 2.2.2 wheels built for CUDA 12.1 (cu121 index URL) and
# fixed versions of the Hugging Face / training libraries.
!pip uninstall -y torch torchvision torchaudio transformers datasets accelerate bitsandbytes
!pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu121
!pip install transformers==4.37.2
!pip install peft==0.10.0
!pip install bitsandbytes==0.42.0
!pip install accelerate==0.27.0
!pip install datasets==2.19.0
# NOTE(review): later cells also import seaborn, scipy, networkx,
# simple_parsing and zstandard, which are not installed here — confirm they
# are already available in the runtime or add them to this cell.
!pip install matplotlib==3.9.0 \
            numpy==1.26.4 \
            pandas==2.2.2 \
            scikit_learn==1.4.2 \
            tqdm==4.66.2 \
            deepspeed==0.14.0

# Restart the kernel (do_shutdown with restart=True) so that the freshly
# installed package versions are the ones imported by the following cells.
import IPython
IPython.Application.instance().kernel.do_shutdown(True)

In [None]:

import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_from_disk
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import pandas as pd
from scipy import stats
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import networkx as nx

In [None]:
import os
from google.colab import drive
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import json
from pathlib import Path
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from dataclasses import dataclass
from typing import Optional, List
from simple_parsing.helpers import Serializable, field

# Mount Google Drive so the project folders below live on Drive.
drive.mount('/content/drive')

# Root folder for everything this notebook reads or writes.
base_dir = "/content/drive/MyDrive/LLM_MIA"

# Sub-folders for caches, datasets and experiment outputs.
cache_dir, data_dir, results_dir = (
    os.path.join(base_dir, sub_name) for sub_name in ("cache_dir", "data", "results")
)

# Environment variables for the MIMIR tooling (presumably read by MIMIR to
# locate its cache and data source — confirm against the MIMIR docs).
os.environ.update({
    'MIMIR_CACHE_PATH': cache_dir,
    'MIMIR_DATA_SOURCE': data_dir,
})

# Create the folder tree; existing folders are left untouched.
for _folder in (cache_dir, data_dir, results_dir):
    os.makedirs(_folder, exist_ok=True)

print(f"All data is stored in the data directory.: {base_dir}")



# Data

## Back Translate

In [None]:
import random
import zstandard as zstd
import json
import os
import io
from google.colab import drive
from tqdm import tqdm

from datasets import load_dataset, Dataset, concatenate_datasets

import torch
from transformers import MarianMTModel, MarianTokenizer

# ============== Part 1: Google Drive Mount & Utility Functions ==============

# Mount Google Drive at /content/drive so files under MyDrive are reachable.
drive.mount(mountpoint='/content/drive')

def read_zst_jsonl(file_path):
    """Lazily yield each line of a Zstandard-compressed JSONL file, stripped.

    The file is decompressed through a streaming reader and decoded as UTF-8,
    so it is never fully decompressed into memory. JSON parsing is left to
    the caller.

    Args:
        file_path: Path to a ``.zst`` compressed JSONL file.

    Yields:
        str: One line per record with surrounding whitespace (including the
        trailing newline) removed.

    Raises:
        Any error raised while opening, decompressing or decoding the file is
        printed and then re-raised.
    """
    try:
        with open(file_path, "rb") as f:
            dctx = zstd.ZstdDecompressor()
            with dctx.stream_reader(f) as reader:
                with io.TextIOWrapper(reader, encoding="utf-8") as text_reader:
                    # `str.strip()` cannot fail, so no per-line try/except is
                    # needed; decoding errors surface from the iterator and are
                    # handled by the outer handler.
                    for line in tqdm(text_reader, desc=f"load {os.path.basename(file_path)}"):
                        yield line.strip()
    except Exception as e:
        print(f"Error reading file {file_path}: {e}")
        raise

def load_jsonl_to_dataset(file_path, filter_func=None, limit=None):
    """Parse a ``.zst`` JSONL file into a list of records.

    Args:
        file_path: Path to the compressed JSONL file (read via ``read_zst_jsonl``).
        filter_func: Optional predicate; records for which it is falsy are
            counted as rejected and skipped.
        limit: Stop once this many records have been kept (ignored when falsy).

    Returns:
        list: The decoded records that passed the filter. Lines that are not
        valid JSON are reported and skipped.
    """
    try:
        kept = []
        seen = 0
        dropped = 0

        for raw_line in read_zst_jsonl(file_path):
            seen += 1
            try:
                record = json.loads(raw_line)
                accepted = filter_func is None or filter_func(record)
            except json.JSONDecodeError as e:
                print(f"Error decoding JSON line: {e}")
                continue

            if not accepted:
                dropped += 1
                continue

            kept.append(record)
            if limit and len(kept) >= limit:
                break

        print(f"Processed {seen} lines; after filtering, selected {len(kept)} samples and rejected {dropped} samples.")
        return kept
    except Exception as e:
        print(f"Error creating dataset: {e}")
        raise

def is_pile_cc_data(item):
    """Return True for a Pile-CC record whose text is shorter than 512 characters.

    Args:
        item: A decoded Pile JSONL record, expected to be a dict with a
            ``'text'`` string and a ``'meta'`` dict containing ``'pile_set_name'``.

    Returns:
        bool: ``True`` only for short Pile-CC records. ``False`` for long texts,
        other Pile subsets, or records that do not have the expected shape.
    """
    try:
        content = item.get('text', '')
        if len(content) >= 512:
            return False
        return item.get('meta', {}).get('pile_set_name') == "Pile-CC"
    except (AttributeError, TypeError):
        # Malformed record (not a dict, text is None / unsized, meta not a dict):
        # treat it as "not Pile-CC" rather than aborting the scan. Only these
        # errors are caught so that e.g. KeyboardInterrupt still propagates.
        return False

def stream_pile_cc_data(limit=None):
    """Stream the Pile and collect Pile-CC records shorter than 512 characters.

    Reads the ``train`` split of ``monology/pile-uncopyrighted`` in streaming
    mode and keeps records whose ``meta.pile_set_name`` is ``"Pile-CC"`` and
    whose text has fewer than 512 characters.

    Args:
        limit: Stop after this many matching records (when truthy); otherwise
            the whole stream is scanned.

    Returns:
        list[dict]: The matching raw records, in stream order.
    """
    pile_cc_data = []
    total_processed = 0
    rejected_long = 0

    ds_stream = load_dataset("monology/pile-uncopyrighted", split="train", streaming=True)

    for item in tqdm(ds_stream, desc="load Pile-CC data"):
        total_processed += 1
        if item['meta']['pile_set_name'] != "Pile-CC":
            continue

        if len(item.get('text', '')) < 512:
            pile_cc_data.append(item)
            if limit and len(pile_cc_data) >= limit:
                break
        else:
            rejected_long += 1

    print(f"Total processed: {total_processed}; selected {len(pile_cc_data)} eligible Pile-CC entries.")
    # Rejection condition is len >= 512, so report it as such.
    print(f"Pile-CC entries rejected due to length ≥ 512 characters: {rejected_long}")
    return pile_cc_data

def check_content_overlap(ds_train, ds_dev, ds_test):
    """Count distinct texts shared by each pair of splits.

    Each argument only needs to support ``ds['text']`` returning an iterable
    of strings.

    Returns:
        tuple[int, int, int]: ``(train∩dev, train∩test, dev∩test)`` sizes.
    """
    train_set, dev_set, test_set = (
        set(split['text']) for split in (ds_train, ds_dev, ds_test)
    )
    return (
        len(train_set & dev_set),
        len(train_set & test_set),
        len(dev_set & test_set),
    )

def check_sample_lengths(dataset, name):
    """Print character-length statistics for one split and return the lengths.

    Args:
        dataset: Anything supporting ``len()`` and column access by name
            (e.g. a ``datasets.Dataset``) with ``'text'`` and ``'label'`` columns.
        name: Split name used in the printed header.

    Returns:
        list[int]: Character length of every text, in dataset order.
    """
    # Fetch each column once: on a `datasets.Dataset`, `dataset['text']` builds
    # a full Python list, so indexing it inside a per-row loop is quadratic.
    texts = dataset['text']
    lengths = [len(text) for text in texts]

    print(f"\n{name} length check:")
    print(f"Number of samples: {len(dataset)}")
    if not lengths:
        # min()/max() and the percentage below are undefined for an empty split.
        print("No samples; skipping length statistics.")
        return lengths

    max_length = max(lengths)
    min_length = min(lengths)
    avg_length = sum(lengths) / len(lengths)
    over_512 = sum(1 for l in lengths if l >= 512)

    print(f"Min length: {min_length}, Max length: {max_length}, Avg length: {avg_length:.1f}")
    print(f"Samples ≥ 512 characters: {over_512}/{len(dataset)} ({over_512/len(dataset)*100:.1f}%)")

    labels = dataset['label']
    lengths_0 = [l for l, lab in zip(lengths, labels) if lab == 0]
    lengths_1 = [l for l, lab in zip(lengths, labels) if lab == 1]

    print(f"label 0 sample: number={len(lengths_0)}, avg length={sum(lengths_0)/max(1,len(lengths_0)):.1f}")
    print(f"label 1 sample: number={len(lengths_1)}, avg length={sum(lengths_1)/max(1,len(lengths_1)):.1f}")

    return lengths

def filter_text_by_length(text, max_length=512):
    """Return True when *text* is strictly shorter than *max_length* characters."""
    too_long = len(text) >= max_length
    return not too_long

# ============== Part 2: Back-Translation (EN → FR → EN) — Required Functions & Model Initialization ==============
def load_translation_model(model_name):
    """Load a MarianMT tokenizer and model for the given checkpoint name.

    The model is loaded with ``device_map=None`` and ``low_cpu_mem_usage=False``,
    so device placement is left to the caller.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    load_options = {"device_map": None, "low_cpu_mem_usage": False}
    marian_tokenizer = MarianTokenizer.from_pretrained(model_name)
    marian_model = MarianMTModel.from_pretrained(model_name, **load_options)
    return marian_tokenizer, marian_model

# Translation models for the EN -> FR -> EN round trip, placed on the GPU
# when one is available and on the CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"

en2fr_name = "Helsinki-NLP/opus-mt-en-fr"
fr2en_name = "Helsinki-NLP/opus-mt-fr-en"

tokenizer_en2fr, model_en2fr = load_translation_model(en2fr_name)
model_en2fr.to(device)

tokenizer_fr2en, model_fr2en = load_translation_model(fr2en_name)
model_fr2en.to(device)

def translate(texts, tokenizer, model, max_length=512):
    """Translate a batch of strings with a seq2seq model.

    Inputs are padded to a common length and truncated to ``max_length``
    tokens, moved to the module-level ``device``, and decoded without special
    tokens. Generation runs under ``torch.no_grad()``.

    Returns:
        list[str]: One translation per input string.
    """
    encoded = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
    encoded = {key: tensor.to(device) for key, tensor in encoded.items()}

    with torch.no_grad():
        token_ids = model.generate(**encoded, max_length=max_length)
    return tokenizer.batch_decode(token_ids, skip_special_tokens=True)

def back_translate_batch(batch):
    """Paraphrase ``batch['text']`` by round-tripping English -> French -> English.

    Intended for ``Dataset.map(..., batched=True)``: returns a dict that
    replaces the ``'text'`` column.
    """
    french = translate(batch["text"], tokenizer_en2fr, model_en2fr)
    return {"text": translate(french, tokenizer_fr2en, model_fr2en)}

# ============== Part 3: Main Pipeline ==============
def main():
    """Build the CC-News / Pile-CC membership-inference splits and back-translate them.

    Steps:
      1. Load CC-News and keep texts shorter than 512 characters.
      2. Stream 1000 Pile-CC texts shorter than 512 characters.
      3. Split into train (200 CC-News texts, labelled 100 x 1 / 100 x 0),
         dev (200 Pile-CC label 1 + 200 CC-News label 0) and
         test (400 Pile-CC label 1 + 400 CC-News label 0).
      4. Back-translate (EN -> FR -> EN) every train text and only the
         label-0 texts of dev/test; the untouched text is kept in the
         ``original_text`` column.
      5. Save the three splits under ``data_dir`` with a ``_bt`` suffix.

    Any exception is printed together with its traceback instead of being
    propagated.
    """
    data_dir = "/content/drive/MyDrive/LLM_MIA/data"
    os.makedirs(data_dir, exist_ok=True)

    def _back_translate_nonmembers(ds, desc):
        # Back-translate label-0 rows only, keep label-1 rows verbatim, then
        # merge and shuffle so the two labels are interleaved again.
        ds_member = ds.filter(lambda x: x["label"] == 1)
        ds_nonmember = ds.filter(lambda x: x["label"] == 0)
        ds_nonmember = ds_nonmember.map(
            back_translate_batch,
            batched=True,
            batch_size=16,
            desc=desc,
        )
        return concatenate_datasets([ds_member, ds_nonmember]).shuffle(seed=42)

    try:
        # 1) Load CC-News data
        cc_news = load_dataset("vblagoje/cc_news")
        print(f"Loaded successfully. CC-News dataset size: {len(cc_news['train'])}")

        # 2) Keep CC-News texts with length < 512 characters. Reading only the
        #    'text' column avoids decoding every field of every row.
        filtered_cc_news_texts = [
            text
            for text in tqdm(cc_news['train']['text'], desc="Filtering CC-News dataset")
            if len(text) < 512
        ]

        print(f"Filtered CC-News dataset size: {len(filtered_cc_news_texts)}")
        if len(filtered_cc_news_texts) < 1000:
            raise ValueError(f"CC-News data is insufficient: only {len(filtered_cc_news_texts)} samples; at least 1000 are required.")

        # 3) Load Pile-CC data (used as the label-1 samples of dev/test)
        pile_cc_data = stream_pile_cc_data(limit=1000)
        pile_cc_texts = [item['text'] for item in pile_cc_data]
        print(f"Number of Pile-CC: {len(pile_cc_texts)}")

        random.seed(42)

        # 4) Split data
        random.shuffle(filtered_cc_news_texts)
        train_size = 200
        dev_nonmember_size = 200
        test_nonmember_size = 400

        required_cc_news = train_size + dev_nonmember_size + test_nonmember_size
        if len(filtered_cc_news_texts) < required_cc_news:
            raise ValueError(f"Insufficient CC-News data: need {required_cc_news} samples, only {len(filtered_cc_news_texts)} available.")

        train_cc_news = filtered_cc_news_texts[:train_size]
        dev_nonmember = filtered_cc_news_texts[train_size:train_size + dev_nonmember_size]
        test_nonmember = filtered_cc_news_texts[train_size + dev_nonmember_size:required_cc_news]

        # Pile-CC split
        dev_member_size = 200
        test_member_size = 400
        required_pile_cc = dev_member_size + test_member_size
        if len(pile_cc_texts) < required_pile_cc:
            raise ValueError(f"Insufficient Pile-CC data: need {required_pile_cc} samples, only {len(pile_cc_texts)} available.")

        random.shuffle(pile_cc_texts)
        dev_member = pile_cc_texts[:dev_member_size]
        test_member = pile_cc_texts[dev_member_size:required_pile_cc]

        # Build train: 200 samples -> 100 (label 1) + 100 (label 0).
        # NOTE: both halves come from CC-News; the label is assigned by
        # position after shuffling, not by Pile membership.
        random.shuffle(train_cc_news)
        train_member = train_cc_news[:train_size // 2]
        train_nonmember = train_cc_news[train_size // 2:train_size]

        ds_train = Dataset.from_dict({
            'text': train_member + train_nonmember,
            'label': [1] * (train_size // 2) + [0] * (train_size // 2)
        })

        # Build dev: 200 (member) + 200 (nonmember)
        ds_dev = Dataset.from_dict({
            'text': dev_member + dev_nonmember,
            'label': [1] * dev_member_size + [0] * dev_nonmember_size
        })

        # Build test: 400 (member) + 400 (nonmember)
        ds_test = Dataset.from_dict({
            'text': test_member + test_nonmember,
            'label': [1] * test_member_size + [0] * test_nonmember_size
        })

        # ========== Check dataset overlap & print info ==========
        train_dev_overlap, train_test_overlap, dev_test_overlap = check_content_overlap(ds_train, ds_dev, ds_test)
        print("\nDataset overlap check:")
        print(f"Train-Dev overlap: {train_dev_overlap}")
        print(f"Train-Test overlap: {train_test_overlap}")
        print(f"Dev-Test overlap: {dev_test_overlap}")
        if train_dev_overlap > 0 or train_test_overlap > 0 or dev_test_overlap > 0:
            print("warning：Overlap detected between datasets!")

        check_sample_lengths(ds_train, "train")
        check_sample_lengths(ds_dev, "dev")
        check_sample_lengths(ds_test, "test")

        print("\n data size:")
        print(f"Train size: {len(ds_train)}")
        print(f"Dev size:   {len(ds_dev)}")
        print(f"Test size:  {len(ds_test)}")

        # Keep the pre-back-translation text for later comparison.
        ds_train = ds_train.add_column("original_text", ds_train["text"])
        ds_dev = ds_dev.add_column("original_text", ds_dev["text"])
        ds_test = ds_test.add_column("original_text", ds_test["text"])

        # 5) Back-translation: all of train; only nonmembers (label=0) of dev/test.
        print("\n1) Back-translate training set (all samples)...")
        ds_train = ds_train.map(
            back_translate_batch,
            batched=True,
            batch_size=16,
            desc="Back-translate train (member + nonmember)"
        )

        print("\n2) Back-translate validation set nonmembers only...")
        ds_dev = _back_translate_nonmembers(ds_dev, "Back-translate dev nonmembers")

        print("\n3) Back-translate test set nonmembers only...")
        ds_test = _back_translate_nonmembers(ds_test, "Back-translate test nonmembers")

        # 6) Save data
        ds_train.save_to_disk(os.path.join(data_dir, "pile_cc_mia_train_bt"))
        ds_dev.save_to_disk(os.path.join(data_dir, "pile_cc_mia_dev_bt"))
        ds_test.save_to_disk(os.path.join(data_dir, "pile_cc_mia_test_bt"))

        print("\nBack-translated datasets have been successfully saved to:")
        print(" -", os.path.join(data_dir, "pile_cc_mia_train_bt"))
        print(" -", os.path.join(data_dir, "pile_cc_mia_dev_bt"))
        print(" -", os.path.join(data_dir, "pile_cc_mia_test_bt"))

    except Exception as e:
        import traceback
        print(f"error: {str(e)}")
        traceback.print_exc()

if __name__ == "__main__":
    main()


In [None]:
# Reload the back-translated splits written by the pipeline above.
data_dir = "/content/drive/MyDrive/LLM_MIA/data"

ds_train, ds_dev, ds_test = (
    load_from_disk(os.path.join(data_dir, split_dir))
    for split_dir in ("pile_cc_mia_train_bt", "pile_cc_mia_dev_bt", "pile_cc_mia_test_bt")
)

print(f"train size: {len(ds_train)}")
print(f"dev size: {len(ds_dev)}")
print(f"test size: {len(ds_test)}")

# Show one record so the column layout can be inspected.
print(ds_test[0])

## BERT Mask

In [None]:
import random
import zstandard as zstd
import json
import os
import io
from google.colab import drive
from tqdm import tqdm

from datasets import load_dataset, Dataset, concatenate_datasets

import torch
from transformers import pipeline, AutoModelForMaskedLM, AutoTokenizer

###############################################################################
# Part 1: Google Drive Mount & Utility Functions
###############################################################################

# Mount Google Drive at /content/drive (this can be skipped when the drive is
# already mounted in the current Colab session).
drive.mount(mountpoint='/content/drive')

def read_zst_jsonl(file_path):
    """
    Reads a .zst compressed JSONL file, yielding one string per line.

    The file is decompressed through a streaming reader and decoded as UTF-8,
    so it is never fully decompressed into memory. Each yielded line has its
    surrounding whitespace (including the trailing newline) stripped; JSON
    parsing is left to the caller.

    Raises:
        Any error raised while opening, decompressing or decoding the file is
        printed and then re-raised.
    """
    try:
        with open(file_path, "rb") as f:
            dctx = zstd.ZstdDecompressor()
            with dctx.stream_reader(f) as reader:
                with io.TextIOWrapper(reader, encoding="utf-8") as text_reader:
                    # `str.strip()` cannot fail, so no per-line try/except is
                    # needed; decoding errors surface from the iterator and are
                    # handled by the outer handler.
                    for line in tqdm(text_reader, desc=f"Loading {os.path.basename(file_path)}"):
                        yield line.strip()
    except Exception as e:
        print(f"Error reading file {file_path}: {e}")
        raise

def load_jsonl_to_dataset(file_path, filter_func=None, limit=None):
    """
    Loads a zst-jsonl file into a list, with optional filtering and limiting.

    Lines that are not valid JSON are reported and skipped. Records for which
    ``filter_func`` is falsy are counted as rejected. Reading stops once
    ``limit`` records have been kept (when ``limit`` is truthy).
    """
    try:
        kept = []
        seen = 0
        dropped = 0

        for raw_line in read_zst_jsonl(file_path):
            seen += 1
            try:
                record = json.loads(raw_line)
                accepted = filter_func is None or filter_func(record)
            except json.JSONDecodeError as e:
                print(f"Error decoding JSON line: {e}")
                continue

            if not accepted:
                dropped += 1
                continue

            kept.append(record)
            if limit and len(kept) >= limit:
                break

        print(f"Processed a total of {seen} lines, passed {len(kept)} samples, and rejected {dropped} samples")
        return kept
    except Exception as e:
        print(f"Error creating dataset: {e}")
        raise

def is_pile_cc_data(item):
    """
    Checks if the data is from Pile-CC and filters for text length less than 512.

    Args:
        item: A decoded Pile JSONL record, expected to be a dict with a
            ``'text'`` string and a ``'meta'`` dict containing ``'pile_set_name'``.

    Returns:
        bool: ``True`` only for Pile-CC records with fewer than 512 characters;
        ``False`` otherwise, including for records without the expected shape.
    """
    try:
        content = item.get('text', '')
        if len(content) >= 512:
            return False
        return item.get('meta', {}).get('pile_set_name') == "Pile-CC"
    except (AttributeError, TypeError):
        # Malformed record (not a dict, text is None / unsized, meta not a dict):
        # treat it as "not Pile-CC". Only these errors are caught so that
        # e.g. KeyboardInterrupt still propagates.
        return False

def stream_pile_cc_data(limit=None):
    """
    Streams Pile-CC data, ensuring the length is less than 512.

    Reads the ``train`` split of ``monology/pile-uncopyrighted`` in streaming
    mode and keeps records whose ``meta.pile_set_name`` is ``"Pile-CC"`` and
    whose text has fewer than 512 characters.

    Args:
        limit: Stop after this many matching records (when truthy); otherwise
            the whole stream is scanned.

    Returns:
        list[dict]: The matching raw records, in stream order.
    """
    pile_cc_data = []
    total_processed = 0
    rejected_long = 0

    ds_stream = load_dataset("monology/pile-uncopyrighted", split="train", streaming=True)

    for item in tqdm(ds_stream, desc="Loading Pile-CC data"):
        total_processed += 1
        if item['meta']['pile_set_name'] != "Pile-CC":
            continue

        if len(item.get('text', '')) < 512:
            pile_cc_data.append(item)
            if limit and len(pile_cc_data) >= limit:
                break
        else:
            rejected_long += 1

    print(f"Processed a total of {total_processed} items, filtered {len(pile_cc_data)} qualifying Pile-CC items")
    # Rejection condition is len >= 512, so report it as such.
    print(f"Pile-CC items rejected due to length >= 512 characters: {rejected_long}")
    return pile_cc_data

def check_content_overlap(ds_train, ds_dev, ds_test):
    """
    Checks for text overlap between different datasets.

    Each argument only needs to support ``ds['text']`` returning an iterable of
    strings. Counts are of distinct texts present in both splits.

    Returns:
        tuple[int, int, int]: ``(train∩dev, train∩test, dev∩test)`` sizes.
    """
    train_texts = set(ds_train['text'])
    dev_texts = set(ds_dev['text'])
    test_texts = set(ds_test['text'])

    # Intersect the prepared sets (the test column was previously re-read
    # instead of using `test_texts`).
    train_dev_overlap = train_texts & dev_texts
    train_test_overlap = train_texts & test_texts
    dev_test_overlap = dev_texts & test_texts

    return len(train_dev_overlap), len(train_test_overlap), len(dev_test_overlap)

def check_sample_lengths(dataset, name):
    """
    Checks the length of all samples in the dataset.

    Prints min/max/average character length, the share of samples with at
    least 512 characters, and per-label counts and average lengths.

    Args:
        dataset: Anything supporting ``len()`` and column access by name
            (e.g. a ``datasets.Dataset``) with ``'text'`` and ``'label'`` columns.
        name: Split name used in the printed header.

    Returns:
        list[int]: Character length of every text, in dataset order.
    """
    # Fetch each column once: on a `datasets.Dataset`, `dataset['text']` builds
    # a full Python list, so indexing it inside a per-row loop is quadratic.
    texts = dataset['text']
    lengths = [len(text) for text in texts]

    print(f"\n{name} Length Check:")
    print(f"Number of samples: {len(dataset)}")
    if not lengths:
        # min()/max() and the percentage below are undefined for an empty split.
        print("No samples; skipping length statistics.")
        return lengths

    max_length = max(lengths)
    min_length = min(lengths)
    avg_length = sum(lengths) / len(lengths)
    over_512 = sum(1 for l in lengths if l >= 512)

    print(f"Min length: {min_length}, Max length: {max_length}, Avg length: {avg_length:.1f}")
    print(f"Samples >= 512 characters: {over_512}/{len(dataset)} ({(over_512/len(dataset))*100:.1f}%)")

    # Check the length distribution for label=0 and label=1 samples respectively
    labels = dataset['label']
    lengths_0 = [l for l, lab in zip(lengths, labels) if lab == 0]
    lengths_1 = [l for l, lab in zip(lengths, labels) if lab == 1]

    print(f"Label 0 samples: Count={len(lengths_0)}, Avg length={sum(lengths_0)/max(1,len(lengths_0)):.1f}")
    print(f"Label 1 samples: Count={len(lengths_1)}, Avg length={sum(lengths_1)/max(1,len(lengths_1)):.1f}")

    return lengths

def filter_text_by_length(text, max_length=512):
    """
    Return True when *text* is strictly shorter than *max_length* characters,
    i.e. the sample should be kept.
    """
    too_long = len(text) >= max_length
    return not too_long


###############################################################################
# Part 2: Text Rewriting Using BERT Model
###############################################################################

# Masked-LM used to rewrite text. Other BERT variants
# (e.g. "bert-large-uncased") can be substituted here.
print("\n== Initializing BERT model for text rewriting ==")
model_name = "bert-base-uncased"

# Use the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name).to(device)

def bert_rewrite_text(text, mask_prob=0.15, max_length=512, generator=None):
    """
    Rewrites text using the BERT model.

    Independently masks each non-special token with probability ``mask_prob``
    and replaces every masked position with the model's most likely (argmax)
    prediction. Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        text: Input string. Empty or whitespace-only strings are returned as-is.
        mask_prob: Per-token masking probability; special tokens such as
            [CLS]/[SEP] are never masked.
        max_length: Token limit passed to the tokenizer (longer input is truncated).
        generator: Optional CPU ``torch.Generator`` for the masking draw. When
            ``None`` the global torch RNG is used, so ``torch.manual_seed``
            makes the rewrite repeatable.

    Returns:
        str: The decoded rewritten text with special tokens removed.
        NOTE: with an uncased checkpoint (e.g. bert-base-uncased) the decoded
        text is lower-cased even where nothing was masked.
    """
    # Ensure the text is not empty
    if not text or len(text.strip()) == 0:
        return text

    # Tokenize a single string -> input_ids has shape (1, seq_len). The masking
    # decisions are made on the CPU before anything is moved to `device`.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
    input_ids = inputs["input_ids"].clone()

    # Build the special-token mask from a plain list of ids; handing the
    # tokenizer a (possibly CUDA) tensor makes it compare ids one tensor
    # element at a time.
    special_tokens_mask = torch.tensor(
        tokenizer.get_special_tokens_mask(input_ids[0].tolist(), already_has_special_tokens=True),
        dtype=torch.bool,
    )

    # Randomly select tokens to mask, never touching special tokens.
    probability_matrix = torch.full(input_ids.shape, float(mask_prob))
    probability_matrix[0, special_tokens_mask] = 0.0
    masked_indices = torch.bernoulli(probability_matrix, generator=generator).bool()

    input_ids[masked_indices] = tokenizer.mask_token_id

    # Move everything the model and the indexing need onto one device.
    input_ids = input_ids.to(device)
    masked_indices = masked_indices.to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

    # Replace only the masked positions with the argmax predictions.
    input_ids[masked_indices] = logits.argmax(dim=-1)[masked_indices]

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

def prompt_rewrite_batch(batch):
    """
    Rewrite every string in ``batch['text']`` with ``bert_rewrite_text``
    (default settings). Intended for ``Dataset.map(..., batched=True)``; the
    returned dict replaces the ``'text'`` column.
    """
    return {"text": list(map(bert_rewrite_text, batch["text"]))}


###############################################################################
# Part 3: Main Workflow
###############################################################################

def main():
    """
    Build the CC-News / Pile-CC membership-inference splits and rewrite them with BERT.

    Steps:
      1. Load CC-News and keep texts shorter than 512 characters.
      2. Stream 1000 Pile-CC texts shorter than 512 characters.
      3. Split into train (200 CC-News texts labelled 100 x 1 / 100 x 0),
         dev (200 Pile-CC label 1 + 200 CC-News label 0) and
         test (400 Pile-CC label 1 + 400 CC-News label 0).
      4. Rewrite every train text, and only the label-0 texts of dev/test,
         with ``prompt_rewrite_batch``; the untouched text is kept in
         ``original_text``.
      5. Print a few before/after examples and save the splits with a
         ``_bert`` suffix.

    Both Python's ``random`` and torch's global RNG are seeded with 42, so the
    splits and the masking draws (taken from torch's global RNG) repeat across
    runs in the same environment. Any exception is printed with its traceback
    instead of being propagated.
    """
    data_dir = "/content/drive/MyDrive/LLM_MIA/data"
    os.makedirs(data_dir, exist_ok=True)

    def _rewrite_nonmembers_only(ds, desc):
        # Rewrite label-0 rows only, keep label-1 rows verbatim, then merge and
        # shuffle so the two labels are interleaved again.
        ds_member = ds.filter(lambda x: x["label"] == 1)
        ds_nonmember = ds.filter(lambda x: x["label"] == 0)
        ds_nonmember = ds_nonmember.map(
            prompt_rewrite_batch,
            batched=True,
            batch_size=16,
            desc=desc,
        )
        return concatenate_datasets([ds_member, ds_nonmember]).shuffle(seed=42)

    try:
        # 1) Load CC News data
        print("Starting to load CC News dataset...")
        cc_news = load_dataset("vblagoje/cc_news")
        print(f"Loading complete. CC News dataset size: {len(cc_news['train'])}")

        # 2) Filter CC News for texts with length < 512. Reading only the
        #    'text' column avoids decoding every field of every row.
        filtered_cc_news_texts = [
            text
            for text in tqdm(cc_news['train']['text'], desc="Filtering CC News data")
            if len(text) < 512
        ]

        print(f"Filtered CC News data size: {len(filtered_cc_news_texts)}")
        if len(filtered_cc_news_texts) < 1000:
            raise ValueError(f"Insufficient CC News data. Found {len(filtered_cc_news_texts)} records, but at least 1000 are required.")

        # 3) Load Pile-CC data (limit=1000)
        pile_cc_data = stream_pile_cc_data(limit=1000)
        pile_cc_texts = [item['text'] for item in pile_cc_data]
        print(f"Pile-CC data size: {len(pile_cc_texts)}")

        # Seed both RNGs: `random` drives the splits and the example review,
        # torch's global RNG drives the BERT masking draws.
        random.seed(42)
        torch.manual_seed(42)

        # 4) Split the datasets
        random.shuffle(filtered_cc_news_texts)
        train_size = 200
        dev_nonmember_size = 200
        test_nonmember_size = 400

        required_cc_news = train_size + dev_nonmember_size + test_nonmember_size
        if len(filtered_cc_news_texts) < required_cc_news:
            raise ValueError(f"Insufficient CC News data. Need {required_cc_news} records, but only have {len(filtered_cc_news_texts)}.")

        train_cc_news = filtered_cc_news_texts[:train_size]
        dev_nonmember = filtered_cc_news_texts[train_size:train_size + dev_nonmember_size]
        test_nonmember = filtered_cc_news_texts[train_size + dev_nonmember_size:required_cc_news]

        # Pile-CC split
        dev_member_size = 200
        test_member_size = 400
        required_pile_cc = dev_member_size + test_member_size
        if len(pile_cc_texts) < required_pile_cc:
            raise ValueError(f"Insufficient Pile-CC data. Need {required_pile_cc} records, but only have {len(pile_cc_texts)}.")

        random.shuffle(pile_cc_texts)
        dev_member = pile_cc_texts[:dev_member_size]
        test_member = pile_cc_texts[dev_member_size:required_pile_cc]

        # Build train: 200 records -> 100 (label 1) + 100 (label 0).
        # NOTE: both halves come from CC-News; the label is assigned by
        # position after shuffling, not by Pile membership.
        random.shuffle(train_cc_news)
        train_member = train_cc_news[:train_size // 2]
        train_nonmember = train_cc_news[train_size // 2:train_size]

        ds_train = Dataset.from_dict({
            'text':  train_member + train_nonmember,
            'label': [1] * (train_size // 2) + [0] * (train_size // 2)
        })

        # Build dev: 200 (member) + 200 (nonmember)
        ds_dev = Dataset.from_dict({
            'text':  dev_member + dev_nonmember,
            'label': [1] * dev_member_size + [0] * dev_nonmember_size
        })

        # Build test: 400 (member) + 400 (nonmember)
        ds_test = Dataset.from_dict({
            'text':  test_member + test_nonmember,
            'label': [1] * test_member_size + [0] * test_nonmember_size
        })

        # ========== Check for dataset overlap and print info ==========
        train_dev_overlap, train_test_overlap, dev_test_overlap = check_content_overlap(ds_train, ds_dev, ds_test)
        print("\nDataset Overlap Check:")
        print(f"Train-Dev overlap: {train_dev_overlap}")
        print(f"Train-Test overlap: {train_test_overlap}")
        print(f"Dev-Test overlap: {dev_test_overlap}")
        if train_dev_overlap > 0 or train_test_overlap > 0 or dev_test_overlap > 0:
            print("Warning: Overlap detected between datasets!")

        check_sample_lengths(ds_train, "Training Set")
        check_sample_lengths(ds_dev,   "Development Set")
        check_sample_lengths(ds_test,  "Test Set")

        print("\nDataset Sizes:")
        print(f"Train size: {len(ds_train)}")
        print(f"Dev size:   {len(ds_dev)}")
        print(f"Test size:  {len(ds_test)}")

        # Keep the original text for comparison
        ds_train = ds_train.add_column("original_text", ds_train["text"])
        ds_dev   = ds_dev.add_column("original_text", ds_dev["text"])
        ds_test  = ds_test.add_column("original_text", ds_test["text"])

        #####################################################################
        # Part 5: Rewriting with BERT Model
        # Rewrite all of Train; rewrite only non-members (label=0) for Dev/Test
        #####################################################################

        print("\n1) Rewriting the training set (members + non-members)...")
        ds_train = ds_train.map(
            prompt_rewrite_batch,
            batched=True,
            batch_size=16,
            desc="Rewriting train"
        )

        print("\n2) Rewriting the development set (non-members only)...")
        ds_dev = _rewrite_nonmembers_only(ds_dev, "Rewriting dev non-members")

        print("\n3) Rewriting the test set (non-members only)...")
        ds_test = _rewrite_nonmembers_only(ds_test, "Rewriting test non-members")

        #####################################################################
        # Part 6: Reviewing Examples for Comparison
        # Three random rows per split, drawn in train -> dev -> test order.
        #####################################################################
        for section_title, row_tag, split_ds in (
            ("Training Set", "Train", ds_train),
            ("Development Set", "Dev", ds_dev),
            ("Test Set", "Test", ds_test),
        ):
            print(f"\n==== Reviewing before-and-after examples ({section_title}) ====")
            for _ in range(3):
                idx = random.randint(0, len(split_ds) - 1)
                row = split_ds[idx]
                print(f"[{row_tag} idx={idx}, label={row['label']}]")
                print("Original:      ", row["original_text"])
                print("After Rewrite: ", row["text"])
                print("----")

        #####################################################################
        # Part 7: Saving the Rewritten Datasets
        #####################################################################
        ds_train.save_to_disk(os.path.join(data_dir, "pile_cc_mia_train_bert"))
        ds_dev.save_to_disk(os.path.join(data_dir, "pile_cc_mia_dev_bert"))
        ds_test.save_to_disk(os.path.join(data_dir, "pile_cc_mia_test_bert"))

        print("\nRewritten datasets have been successfully saved to:")
        print(" -", os.path.join(data_dir, "pile_cc_mia_train_bert"))
        print(" -", os.path.join(data_dir, "pile_cc_mia_dev_bert"))
        print(" -", os.path.join(data_dir, "pile_cc_mia_test_bert"))

    except Exception as e:
        import traceback
        print(f"An error occurred: {str(e)}")
        traceback.print_exc()

###############################################################################
# Run the main function
###############################################################################
if __name__ == "__main__":
    main()

In [None]:
# Reload the BERT-rewritten splits written by the pipeline above.
data_dir = "/content/drive/MyDrive/LLM_MIA/data"

ds_train, ds_dev, ds_test = (
    load_from_disk(os.path.join(data_dir, split_dir))
    for split_dir in ("pile_cc_mia_train_bert", "pile_cc_mia_dev_bert", "pile_cc_mia_test_bert")
)

print(f"train size: {len(ds_train)}")
print(f"dev size: {len(ds_dev)}")
print(f"test size: {len(ds_test)}")

## GPT Prompt

In [None]:
# Load the GPT-4o prompt-rewritten splits (`*_prompt_gpt4o_simple_v1`).
# NOTE(review): these are not produced by any cell shown here — presumably
# generated by a separate prompting script; confirm the source.
data_dir = "/content/drive/MyDrive/LLM_MIA/data"

ds_train, ds_dev, ds_test = (
    load_from_disk(os.path.join(data_dir, split_dir))
    for split_dir in (
        "pile_cc_mia_train_prompt_gpt4o_simple_v1",
        "pile_cc_mia_dev_prompt_gpt4o_simple_v1",
        "pile_cc_mia_test_prompt_gpt4o_simple_v1",
    )
)

print(f"train size: {len(ds_train)}")
print(f"dev size: {len(ds_dev)}")
print(f"test size: {len(ds_test)}")


## WikiMIA

In [None]:
from datasets import load_dataset
import numpy as np
from datasets import Dataset

# Build balanced dev/test splits from WikiMIA (length-32 subset) and save them.
# NOTE: the two np.random.shuffle calls below consume the global NumPy RNG in
# order, so reordering statements here changes which samples land where.

# Load the dataset
ds = load_dataset("swj0419/WikiMIA")

# Select only the length32 data
length32_data = ds["WikiMIA_length32"]

# Separate member (label=1) and non-member (label=0) samples, and rename the 'input' field to 'text'
member_samples = []
nonmember_samples = []

for item in length32_data:
    # Create a new sample dictionary, changing the 'input' field to 'text'
    new_item = {
        'text': item['input'],
        'label': item['label']
    }
    # Keep other existing fields
    for key, value in item.items():
        if key not in ['input', 'text', 'label']:
            new_item[key] = value

    # Classify based on the label (any label other than 1 goes to the non-member pool)
    if item['label'] == 1:
        member_samples.append(new_item)
    else:
        nonmember_samples.append(new_item)

print(f"Number of member samples: {len(member_samples)}")
print(f"Number of non-member samples: {len(nonmember_samples)}")

# Adjust dataset sizes
# Use the number of available samples to determine the split
# (available_nonmember_count is computed but not used below.)
available_member_count = len(member_samples)
available_nonmember_count = len(nonmember_samples)

# Set the number of member samples in the dev and test sets
# Considering there are only 387 member samples, we can take 140 for dev and the remaining 247 for test
# NOTE(review): 387 is the author's observed count and is not checked here; with
# fewer than 140 members, test_member_count becomes 0 or negative.
dev_member_count = 140
test_member_count = available_member_count - dev_member_count  # should be 247

# Set the same number of non-member samples to maintain balance
# NOTE(review): balance holds only if there are at least available_member_count
# non-members; the slices below silently truncate otherwise.
dev_nonmember_count = dev_member_count
test_nonmember_count = test_member_count

# Set a random seed for reproducibility
# (It only affects the two shuffles below; the split itself takes samples in
# dataset order.)
np.random.seed(42)

# Split the member samples
dev_members = member_samples[:dev_member_count]
test_members = member_samples[dev_member_count:]

# Split the non-member samples
dev_nonmembers = nonmember_samples[:dev_nonmember_count]
test_nonmembers = nonmember_samples[dev_nonmember_count:dev_nonmember_count+test_nonmember_count]

# Combine the dev set
ds_dev = dev_members + dev_nonmembers
np.random.shuffle(ds_dev)  # Shuffle the order (in place, on the Python list)

# Combine the test set
ds_test = test_members + test_nonmembers
np.random.shuffle(ds_test)  # Shuffle the order (in place, on the Python list)

# Verify dataset sizes and label balance
print(f"Dev set size: {len(ds_dev)}")
print(f"Number of member samples in Dev set: {sum(1 for item in ds_dev if item['label'] == 1)}")
print(f"Number of non-member samples in Dev set: {sum(1 for item in ds_dev if item['label'] == 0)}")

print(f"Test set size: {len(ds_test)}")
print(f"Number of member samples in Test set: {sum(1 for item in ds_test if item['label'] == 1)}")
print(f"Number of non-member samples in Test set: {sum(1 for item in ds_test if item['label'] == 0)}")

# Save the datasets (this rebinds the notebook-global ds_dev / ds_test)
ds_dev = Dataset.from_list(ds_dev)
ds_test = Dataset.from_list(ds_test)

# Create the data directory if it doesn't exist
# NOTE(review): "data" is relative to the current working directory, not the
# Drive data_dir used by other cells; confirm this is intended.
import os
os.makedirs("data", exist_ok=True)

# Save to disk
ds_dev.save_to_disk("data/ds_dev")
ds_test.save_to_disk("data/ds_test")

print("Datasets have been saved to the data/ds_dev and data/ds_test directories")

## ArxivMIA

In [None]:
from datasets import load_dataset, Dataset
import shutil
from google.colab import drive

drive.mount('/content/drive')
# NOTE(review): these paths are not under the LLM_MIA root used by the other cells
# (e.g. "/content/drive/MyDrive/LLM_MIA/data"). Running this cell rebinds the
# notebook-global `data_dir` / `cache_dir`. Neither is read in this cell, which
# uses hard-coded paths instead; confirm the override is intended.
data_dir = '/content/drive/MyDrive/data'
cache_dir = '/content/drive/MyDrive/cache_dir'

def check_content_overlap(ds_train, ds_dev, ds_test):
    """Count exact-duplicate texts shared between each pair of splits.

    Each argument only needs to support ``ds['text']`` returning an iterable of strings.
    Duplicates within a single split are counted once.

    Returns:
        ``(train∩dev, train∩test, dev∩test)`` sizes as ints.
    """
    train_texts, dev_texts, test_texts = (
        set(split['text']) for split in (ds_train, ds_dev, ds_test)
    )
    return (
        len(train_texts & dev_texts),
        len(train_texts & test_texts),
        len(dev_texts & test_texts),
    )

if __name__ == "__main__":
    # 1. Training split: the locally prepared JSONL on Drive.
    ds_train = load_dataset(
        "json",
        data_files="/content/drive/MyDrive/LLM_MIA/arxiv_mia_train_real.jsonl",
        split="train",
    )

    # 2. Dev/test splits: two ArxivMIA configs on the Hugging Face Hub, each of
    #    which exposes its rows under the "train" split.
    ds_dev, ds_test = (
        load_dataset("zhliu/ArxivMIA", config)["train"]
        for config in ("arxiv_mia_dev", "arxiv_mia_test")
    )

    print("Train size:", len(ds_train))
    print("Dev size:  ", len(ds_dev))
    print("Test size: ", len(ds_test))

    # 3. Report exact-duplicate texts shared across splits.
    (train_dev_overlap,
     train_test_overlap,
     dev_test_overlap) = check_content_overlap(ds_train, ds_dev, ds_test)
    print("\nChecking dataset overlap:")
    print(f"Train-Dev overlap: {train_dev_overlap}")
    print(f"Train-Test overlap: {train_test_overlap}")
    print(f"Dev-Test overlap: {dev_test_overlap}")

# Model

## ArxivMIA

### TinyLlama

In [None]:
# Import necessary libraries
import os
import torch
import numpy as np
from collections import Counter
from datasets import load_dataset, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.metrics import roc_curve, auc, accuracy_score, precision_recall_fscore_support
import gc
import warnings
warnings.filterwarnings("ignore")

# Disable tqdm - a safer way
import sys
import importlib
import tqdm as tqdm_module

# Remember whatever `tqdm` module object is registered before it is patched below.
original_module = sys.modules.get('tqdm', None)


class DummyTqdmModule:
    """No-op stand-in for a tqdm progress bar.

    It iterates over the wrapped iterable (first positional argument or
    ``iterable=``), or yields nothing when none was given. It also accepts the
    calls commonly made on a manually driven bar (``update``, ``close``,
    ``refresh``, ``set_description``, ``set_postfix``) and works as a context
    manager.
    """

    def __init__(self, *args, **kwargs):
        # tqdm's first positional parameter (or ``iterable=``) is the wrapped iterable.
        self.iterable = args[0] if args else kwargs.get('iterable')
        self._iterator = None  # created lazily, on first next()

    def update(self, *args, **kwargs):
        pass

    def close(self, *args, **kwargs):
        pass

    def refresh(self, *args, **kwargs):
        pass

    def set_description(self, *args, **kwargs):
        pass

    def set_postfix(self, *args, **kwargs):
        pass

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.close()
        return False  # never swallow exceptions raised inside the with-block

    def __iter__(self):
        return self

    def __next__(self):
        if self.iterable is None:
            raise StopIteration
        if self._iterator is None:
            self._iterator = iter(self.iterable)
        return next(self._iterator)


def dummy_tqdm(*args, **kwargs):
    """Drop-in replacement for ``tqdm.tqdm`` that never draws anything.

    A list passed as the first argument is returned unchanged (as before).
    Any other call returns a DummyTqdmModule. When it wraps an iterable, it still
    yields every item; previously non-list iterables such as ``range`` or
    generators were silently turned into an empty iteration.
    """
    if args and isinstance(args[0], list):
        return args[0]
    return DummyTqdmModule(*args, **kwargs)

# Patch the ``tqdm`` attribute of every tqdm submodule that is already imported,
# so code that looks up ``<module>.tqdm`` at call time gets the no-op replacement.
# NOTE(review): code that bound the original class earlier (``from tqdm import tqdm``)
# keeps its own reference and is not affected by this patch.
for name in ['tqdm', 'tqdm.std', 'tqdm.auto', 'tqdm.notebook', 'tqdm.rich', 'tqdm.cli', 'tqdm.gui', 'tqdm.keras']:
    module = sys.modules.get(name)
    if module is None:
        continue  # not imported (or import-blocked): nothing to patch
    try:
        module.tqdm = dummy_tqdm
    except Exception:
        # Best effort: leave any module that rejects attribute assignment untouched.
        pass

tqdm_module.tqdm = dummy_tqdm

# 1. Set up the environment: use the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 3. Load TinyLlama (intermediate 1.1B checkpoint), move it to `device` and put it
#    in evaluation mode; `.to()` and `.eval()` both return the module itself.
model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
model = AutoModelForCausalLM.from_pretrained(model_name).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Fall back to EOS as the padding token only when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 4. Forward hooks record, per hooked module, a boolean mask of which units fired.
activations = {}          # cleared per sample; the hook below never writes to it
activated_neurons = {}    # hook name -> bool CPU tensor shaped like the module output
ACTIVATION_THRESHOLD = 0  # a unit is "activated" when its output is strictly above this


def get_ffn_activation(name):
    """Build a forward hook that stores ``output > ACTIVATION_THRESHOLD`` under *name*.

    Only 2-D (``[batch, hidden]``) and 3-D (``[batch, seq, hidden]``) outputs are
    recorded; other ranks are ignored. The threshold is the module-level global
    read at hook time, so later reassignments affect hooks that are already
    registered.
    """
    def hook(module, input, output):
        if output.dim() not in (2, 3):
            return
        activated_neurons[name] = (output > ACTIVATION_THRESHOLD).detach().cpu()

    return hook

# 5. Attach one forward hook per decoder layer to its MLP activation.
hooks = []

# TinyLlama follows the Llama layout, where the MLP exposes its activation as
# ``act_fn``. Layers without it fall back to the first SiLU/GELU submodule found.
for i, layer in enumerate(model.model.layers):
    mlp = layer.mlp
    if hasattr(mlp, 'act_fn'):
        hooks.append(mlp.act_fn.register_forward_hook(get_ffn_activation(f'layer_{i}_mlp_act')))
        continue
    for name, module in mlp.named_modules():
        if not isinstance(module, (torch.nn.SiLU, torch.nn.GELU)):
            continue
        hooks.append(module.register_forward_hook(get_ffn_activation(f'layer_{i}_mlp_act_{name}')))
        break

print(f"Registered hooks for {len(model.model.layers)}-layer model")

# 7. Collect sample activation data - only which neuron indices fired per hooked layer
def process_sample(sample, sample_id):
    """Run one sample through the hooked model and summarise which neurons fired.

    Args:
        sample: mapping with at least ``'text'`` (tokenised and truncated to 256
            tokens) and ``'label'``.
        sample_id: identifier copied into the result.

    Returns:
        dict with ``'sample_id'``, ``'label'`` and ``'neural_signature'``. The
        signature maps each hook name to the sorted list of neuron indices whose
        activation exceeded ``ACTIVATION_THRESHOLD`` at *any* token position.
        Hooks where nothing fired, or whose mask has an unexpected rank, are omitted.
    """
    # Reset state left behind by the previous sample's forward pass.
    activations.clear()
    activated_neurons.clear()

    encodings = tokenizer(sample['text'], return_tensors="pt", truncation=True, max_length=256)
    encodings = {k: v.to(device) for k, v in encodings.items()}

    # The forward pass runs only for its side effect: the registered hooks fill
    # ``activated_neurons``. The model outputs themselves are not needed.
    with torch.no_grad():
        model(**encodings)

    sample_neural_signature = {}
    for key, value in activated_neurons.items():
        if value.dim() < 2:
            continue
        try:
            mask = value.squeeze(0).numpy()
        except Exception:
            # Best effort (as before): skip masks that cannot be converted.
            continue

        if mask.ndim == 2:    # [seq_len, hidden_dim] -> fired at any position
            active = np.flatnonzero(mask.any(axis=0))
        elif mask.ndim == 1:  # [hidden_dim]
            active = np.flatnonzero(mask)
        else:
            continue

        if active.size:
            sample_neural_signature[key] = active.tolist()

    return {
        'sample_id': sample_id,
        'label': sample['label'],
        'neural_signature': sample_neural_signature
    }

def collect_activations(samples, batch_size=10, process_fn=None):
    """Run ``process_fn`` over *samples* in chunks, freeing memory between chunks.

    Args:
        samples: sliceable sequence of samples (e.g. a list of dicts).
        batch_size: number of samples handled between memory clean-ups (>= 1).
        process_fn: callable ``(sample, sample_id) -> result``. It defaults to
            ``process_sample``, looked up at call time. ``sample_id`` is the
            sample's index in *samples*.

    Returns:
        list of results for the samples processed successfully, in input order.
        Samples whose processing raised are skipped, as before, but the number
        skipped and the last error are now printed instead of being hidden.
    """
    if process_fn is None:
        process_fn = process_sample

    results = []
    skipped = 0
    last_error = None

    for start in range(0, len(samples), batch_size):
        for offset, sample in enumerate(samples[start:start + batch_size]):
            try:
                results.append(process_fn(sample, start + offset))
            except Exception as exc:
                skipped += 1
                last_error = exc

        # Clean up between chunks (not after the last one).
        if start + batch_size < len(samples):
            gc.collect()
            torch.cuda.empty_cache()

    if skipped:
        print(f"Warning: skipped {skipped} of {len(samples)} samples; last error: {last_error!r}")
    return results

# 8. Analyze neuron activation patterns - only neuron activation indices are used
def _activation_frequencies(samples, layer_name):
    """Return ``{neuron: fraction of samples in which it fired in layer_name}``."""
    counts = Counter()
    for sample in samples:
        counts.update(sample['neural_signature'].get(layer_name, ()))
    # counts is empty when samples is empty, so no division by zero can occur.
    return {neuron: count / len(samples) for neuron, count in counts.items()}


def analyze_neuron_activation_patterns(member_samples, nonmember_samples, dominance_ratio=1.5):
    """Compare per-layer neuron activation frequencies of member vs non-member samples.

    Args:
        member_samples, nonmember_samples: sequences of dicts with a
            ``'neural_signature'`` mapping layer name -> activated neuron indices.
        dominance_ratio: a neuron is dominant for one group if it never fires in
            the other group, or fires more than ``dominance_ratio`` times as often.

    Returns:
        ``{layer_name: {...}}`` in sorted layer-name order. Each value holds
        ``member_freq`` / ``nonmember_freq`` ({neuron: frequency}),
        ``member_dominant`` / ``nonmember_dominant`` ({neuron: frequency in that
        group}) and ``common_neurons`` (neurons seen in both groups but dominant in
        neither, mapped to ``(member_freq, nonmember_freq)``).
    """
    # Sorted so that the layer order, and hence tie-breaking in later ranking
    # steps, does not depend on per-process string-hash randomisation of sets.
    layer_names = sorted({
        layer_name
        for group in (member_samples, nonmember_samples)
        for sample in group
        for layer_name in sample['neural_signature']
    })

    results = {}
    for layer_name in layer_names:
        member_freq = _activation_frequencies(member_samples, layer_name)
        nonmember_freq = _activation_frequencies(nonmember_samples, layer_name)

        member_dominant = {
            neuron: freq for neuron, freq in member_freq.items()
            if neuron not in nonmember_freq or freq > nonmember_freq[neuron] * dominance_ratio
        }
        nonmember_dominant = {
            neuron: freq for neuron, freq in nonmember_freq.items()
            if neuron not in member_freq or freq > member_freq[neuron] * dominance_ratio
        }
        common_neurons = {
            neuron: (member_freq[neuron], nonmember_freq[neuron])
            for neuron in member_freq.keys() & nonmember_freq.keys()
            if neuron not in member_dominant and neuron not in nonmember_dominant
        }

        results[layer_name] = {
            'member_dominant': member_dominant,
            'nonmember_dominant': nonmember_dominant,
            'common_neurons': common_neurons,
            'member_freq': member_freq,
            'nonmember_freq': nonmember_freq
        }

    return results

# 9. Build reference patterns
def build_reference_patterns(train_samples):
    """Split *train_samples* by label and derive per-layer reference activation patterns.

    Label 1 means member and label 0 means non-member; any other label is
    ignored. The group sizes are printed before the patterns are computed with
    ``analyze_neuron_activation_patterns``.
    """
    member_samples, nonmember_samples = [], []
    for sample in train_samples:
        label = sample['label']
        if label == 1:
            member_samples.append(sample)
        elif label == 0:
            nonmember_samples.append(sample)

    print(f"Building reference patterns: {len(member_samples)} member samples, {len(nonmember_samples)} non-member samples")

    return analyze_neuron_activation_patterns(member_samples, nonmember_samples)

# 10. Calculate discrimination score for each layer
def calculate_layer_discrimination_scores(reference_patterns):
    """Score each layer as (#member-dominant neurons) - (#non-member-dominant neurons).

    A positive score means the layer leans towards identifying members, and a
    negative score means it leans towards non-members.
    """
    return {
        layer_name: len(patterns['member_dominant']) - len(patterns['nonmember_dominant'])
        for layer_name, patterns in reference_patterns.items()
    }

# 11. Select the most discriminative layers
def select_discriminative_layers(layer_scores, top_n=10):
    """Return the names of the *top_n* layers with the largest absolute score.

    Ties keep the order of *layer_scores* because the sort is stable.
    """
    ranked = sorted(layer_scores, key=lambda layer: abs(layer_scores[layer]), reverse=True)
    return ranked[:top_n]

# 12. Membership prediction based on relative ratio - using simplified neural signatures
def predict_membership_by_relative_ratio(sample, reference_patterns, discriminative_layers, threshold=1.0):
    """Predict membership from overlap with member- vs non-member-dominant neurons.

    For every discriminative layer where the sample activated at least one
    neuron, compute the fraction of that layer's member-dominant neurons the
    sample hit and the same fraction for non-member-dominant neurons. An empty
    dominant set contributes 0. The two fractions are then averaged over the
    contributing layers.

    Returns:
        ``(prediction, avg_member_ratio, avg_nonmember_ratio)``. The prediction is
        1 when ``avg_member_ratio / avg_nonmember_ratio >= threshold``; the ratio is
        treated as infinite when the non-member average is 0. ``(0, 0, 0)`` is
        returned when the sample has no signature or no layer contributes.
    """
    signature = sample['neural_signature']
    if not signature:
        return 0, 0, 0

    per_layer = []  # (member_ratio, nonmember_ratio) for each contributing layer
    for layer_name in discriminative_layers:
        if layer_name not in signature or layer_name not in reference_patterns:
            continue
        active = set(signature[layer_name])
        if not active:
            continue

        patterns = reference_patterns[layer_name]
        member_set = set(patterns['member_dominant'])
        nonmember_set = set(patterns['nonmember_dominant'])
        member_part = len(active & member_set) / len(member_set) if member_set else 0
        nonmember_part = len(active & nonmember_set) / len(nonmember_set) if nonmember_set else 0
        per_layer.append((member_part, nonmember_part))

    if not per_layer:
        return 0, 0, 0

    avg_member_ratio = sum(m for m, _ in per_layer) / len(per_layer)
    avg_nonmember_ratio = sum(n for _, n in per_layer) / len(per_layer)

    if avg_nonmember_ratio == 0:
        ratio = float('inf')
    else:
        ratio = avg_member_ratio / avg_nonmember_ratio

    prediction = 1 if ratio >= threshold else 0
    return prediction, avg_member_ratio, avg_nonmember_ratio

# 14. Find the best threshold on the validation set
def find_best_relative_ratio_threshold(val_samples, reference_patterns, discriminative_layers):
    """Choose the member/non-member ratio threshold that maximises validation accuracy.

    Each sample is scored as ``avg_member_ratio / avg_nonmember_ratio``. The score
    is infinite when the non-member ratio is 0, and such samples are always
    predicted as members. Candidate thresholds are 100 evenly spaced values over
    the range of finite scores plus the fixed points 0.5, 1, 1.5, 2, 3 and 5. On
    ties, the smallest threshold wins.

    Returns:
        ``(best_threshold, metrics)``, where ``metrics`` holds
        accuracy/precision/recall/f1 at that threshold. When no sample has a
        finite score, ``(1.0, default metrics)`` is returned.
    """
    ratios = []
    labels = []
    for sample in val_samples:
        _, member_ratio, nonmember_ratio = predict_membership_by_relative_ratio(
            sample, reference_patterns, discriminative_layers, threshold=1.0
        )
        ratios.append(float('inf') if nonmember_ratio == 0 else member_ratio / nonmember_ratio)
        labels.append(sample['label'])

    # Candidate thresholds come from finite scores only; infinite scores are
    # still included in every prediction below.
    finite_ratios = [r for r in ratios if np.isfinite(r)]
    if not finite_ratios:
        print("Warning: No valid ratio values found. Using default threshold 1.0")
        return 1.0, {
            'accuracy': 0.5,  # Default accuracy
            'precision': 0.0,
            'recall': 0.0,
            'f1': 0.0
        }

    evenly_spaced = np.linspace(min(finite_ratios), max(finite_ratios), 100).tolist()
    candidate_thresholds = sorted(set(evenly_spaced) | {0.5, 1.0, 1.5, 2.0, 3.0, 5.0})

    best_threshold = 1.0
    # Start below any achievable accuracy so a candidate is always selected and
    # the returned metrics are never None (previously they were when every
    # accuracy was 0).
    best_accuracy = -1.0
    best_metrics = None

    for threshold in candidate_thresholds:
        predictions = [1 if r >= threshold else 0 for r in ratios]
        accuracy = accuracy_score(labels, predictions)
        if accuracy <= best_accuracy:
            continue

        precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
        best_accuracy = accuracy
        best_threshold = threshold
        best_metrics = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1
        }

    return best_threshold, best_metrics

# Calculate AUC for the test set
def calculate_auc(test_samples, reference_patterns, discriminative_layers, batch_size=10,
                  include_infinite=False):
    """Compute the ROC AUC of the relative-ratio membership score on *test_samples*.

    Each sample is run through the hooked model via ``collect_activations`` and
    scored as ``member_ratio / nonmember_ratio``; higher means more member-like.

    Args:
        test_samples: sequence of dicts with ``'text'`` and ``'label'`` (1 = member).
        reference_patterns, discriminative_layers: output of the reference-building steps.
        batch_size: number of samples processed per chunk.
        include_infinite: samples with a non-member ratio of 0 get an infinite score,
            which ``roc_curve`` cannot accept. By default (False, the previous
            behaviour) they are left out of the AUC. When True they are kept and
            ranked above every finite score. Either way, the number of excluded
            samples is now printed instead of being silently dropped.

    Returns:
        The AUC as a float. It is 0.5 if no sample has a usable score, and NaN if
        the scored samples contain only one class.
    """
    all_ratios = []
    all_labels = []

    for start in range(0, len(test_samples), batch_size):
        batch_acts = collect_activations(test_samples[start:start + batch_size], batch_size=batch_size)
        for sample in batch_acts:
            _, member_ratio, nonmember_ratio = predict_membership_by_relative_ratio(
                sample, reference_patterns, discriminative_layers
            )
            all_ratios.append(float('inf') if nonmember_ratio == 0 else member_ratio / nonmember_ratio)
            all_labels.append(sample['label'])

        # Release memory between chunks.
        del batch_acts
        gc.collect()
        torch.cuda.empty_cache()

    scores = np.asarray(all_ratios, dtype=float)
    labels = np.asarray(all_labels)
    finite = np.isfinite(scores)

    if include_infinite:
        # Rank "no non-member evidence" samples just above the best finite score.
        ceiling = scores[finite].max() + 1.0 if finite.any() else 1.0
        scores = np.where(np.isposinf(scores), ceiling, scores)
        keep = np.isfinite(scores)
    else:
        keep = finite

    excluded = int(len(scores) - keep.sum())
    if excluded:
        print(f"Note: {excluded} of {len(scores)} test samples had no finite score and were excluded from the AUC.")
    scores, labels = scores[keep], labels[keep]

    if scores.size == 0:
        print("Warning: No valid ratio values found. Cannot calculate AUC.")
        return 0.5  # Return a default value, equivalent to the AUC of a random guess

    if np.unique(labels).size < 2:
        # roc_curve would only warn (warnings are silenced in this notebook) and yield NaN.
        print("Warning: Only one class among scored test samples; AUC is undefined.")
        return float('nan')

    fpr, tpr, _ = roc_curve(labels, scores)
    return auc(fpr, tpr)

# Plain pass-through used instead of a tqdm progress wrapper
def iterate_samples(samples):
    """Return *samples* unchanged so callers can iterate without a progress bar."""
    passthrough = samples
    return passthrough

# Dataset filtering functions
def _filter_by_field(dataset, field_name):
    """Keep only the rows whose ``'field'`` column equals *field_name*."""
    return dataset.filter(lambda row: row['field'] == field_name)


def filter_math_papers(dataset):
    """Return the subset of *dataset* whose ``'field'`` is ``'math'``."""
    return _filter_by_field(dataset, 'math')


def filter_cs_papers(dataset):
    """Return the subset of *dataset* whose ``'field'`` is ``'cs'``."""
    return _filter_by_field(dataset, 'cs')

# Process all thresholds for a single dataset
def _load_arxiv_splits(data_type):
    """Load the ArxivMIA train/dev/test splits, optionally restricted to one subject field.

    Args:
        data_type: ``"arxiv_all"`` (no filtering), ``"arxiv_math"`` or ``"arxiv_cs"``.

    Returns:
        ``(ds_train, ds_dev, ds_test)``. Train comes from the local JSONL on Drive;
        dev and test come from the ``zhliu/ArxivMIA`` configs on the Hugging Face Hub.

    Raises:
        ValueError: if *data_type* is not a supported name. Previously this
            surfaced as an UnboundLocalError.
    """
    field_filters = {
        "arxiv_all": None,
        "arxiv_math": filter_math_papers,
        "arxiv_cs": filter_cs_papers,
    }
    if data_type not in field_filters:
        raise ValueError(
            f"Unknown data_type {data_type!r}; expected one of {sorted(field_filters)}"
        )

    ds_train = load_dataset(
        "json",
        data_files="/content/drive/MyDrive/LLM_MIA/arxiv_mia_train_real.jsonl",
        split="train",
    )
    ds_dev = load_dataset("zhliu/ArxivMIA", "arxiv_mia_dev")["train"]
    ds_test = load_dataset("zhliu/ArxivMIA", "arxiv_mia_test")["train"]

    field_filter = field_filters[data_type]
    if field_filter is not None:
        ds_train = field_filter(ds_train)
        ds_dev = field_filter(ds_dev)
        ds_test = field_filter(ds_test)
    return ds_train, ds_dev, ds_test


def process_dataset(data_type, data_dir="/content/drive/MyDrive/LLM_MIA/data",
                    thresholds=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0)):
    """Evaluate the neuron-activation membership attack on one ArxivMIA variant.

    For every activation threshold, the first half of each dev class builds the
    reference patterns and the second half tunes the decision threshold. The
    test-set AUC is then printed.

    Args:
        data_type: ``"arxiv_all"``, ``"arxiv_math"`` or ``"arxiv_cs"``.
        data_dir: not used by this function; kept so existing calls keep working.
        thresholds: activation thresholds to evaluate, in order. Each one is
            written to the global ``ACTIVATION_THRESHOLD`` that the forward hooks read.

    Raises:
        ValueError: for an unsupported *data_type*.
    """
    global ACTIVATION_THRESHOLD

    print(f"\n===== Dataset: {data_type} =====")

    # NOTE: the training split is loaded (as before) but not otherwise used here;
    # reference patterns come from the dev split.
    ds_train, ds_dev, ds_test = _load_arxiv_splits(data_type)

    # None of these depend on the activation threshold, so build them once rather
    # than re-iterating the datasets on every threshold.
    dev_member = [item for item in ds_dev if item['label'] == 1]
    dev_nonmember = [item for item in ds_dev if item['label'] == 0]
    member_half = len(dev_member) // 2
    nonmember_half = len(dev_nonmember) // 2
    train_member, val_member = dev_member[:member_half], dev_member[member_half:]
    train_nonmember, val_nonmember = dev_nonmember[:nonmember_half], dev_nonmember[nonmember_half:]
    test_samples = list(ds_test)

    for threshold in thresholds:
        # The registered hooks read this global on every forward pass.
        ACTIVATION_THRESHOLD = threshold

        train_member_acts = collect_activations(train_member)
        train_nonmember_acts = collect_activations(train_nonmember)
        val_member_acts = collect_activations(val_member)
        val_nonmember_acts = collect_activations(val_nonmember)

        reference_patterns = build_reference_patterns(
            train_member_acts + train_nonmember_acts
        )
        layer_scores = calculate_layer_discrimination_scores(reference_patterns)
        discriminative_layers = select_discriminative_layers(layer_scores, top_n=10)

        # Decision threshold on the member/non-member ratio, tuned on the
        # validation half. The AUC reported below does not use it.
        val_samples = val_member_acts + val_nonmember_acts
        best_threshold, val_metrics = find_best_relative_ratio_threshold(
            val_samples,
            reference_patterns,
            discriminative_layers
        )

        test_auc = calculate_auc(test_samples, reference_patterns, discriminative_layers)
        print(f"Dataset: {data_type}, Threshold: {threshold:.1f}, Test Set AUC = {test_auc:.4f}")

        # Free per-threshold state before the next pass.
        del train_member_acts, train_nonmember_acts, val_member_acts, val_nonmember_acts
        del reference_patterns, layer_scores, discriminative_layers, val_samples
        gc.collect()
        torch.cuda.empty_cache()

    # Release memory after processing the current dataset
    del ds_train, ds_dev, ds_test, test_samples
    gc.collect()
    torch.cuda.empty_cache()

# Main function
if __name__ == "__main__":
    try:
        # Activation thresholds to evaluate
        activation_thresholds = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0]

        # Only the full ArxivMIA set is evaluated here; process_dataset also
        # accepts "arxiv_math" and "arxiv_cs".
        datasets = ["arxiv_all"]

        for dataset in datasets:
            process_dataset(dataset, thresholds=activation_thresholds)

    except Exception as e:
        print(f"An error occurred during execution: {str(e)}")
        raise  # re-raise with the original traceback

    finally:
        # Detach the forward hooks whether or not evaluation succeeded.
        for hook in hooks:
            try:
                hook.remove()
            except Exception:
                pass

### OpenLLaMA 13B

In [None]:
# Import necessary libraries
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from collections import Counter
from scipy import stats
from datasets import load_from_disk
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, roc_curve, auc
import tqdm
from transformers import BitsAndBytesConfig
import gc

# Define a function that does nothing and returns its input
def dummy_tqdm(iterable=None, *args, **kwargs):
    """No-op replacement for ``tqdm.tqdm``.

    Returns *iterable* unchanged when one is given. Otherwise it returns the
    function object itself, so that "bar" method calls such as ``.update()`` or
    ``.close()`` resolve to the attribute lambdas attached below.
    """
    return iterable if iterable is not None else dummy_tqdm

# Add all necessary attributes and methods
# These are plain function attributes, so e.g. ``dummy_tqdm.update(1)`` calls the
# lambda directly (no ``self`` is bound).
dummy_tqdm.format_interval = lambda x: f"{x:.1f}s"
dummy_tqdm.format_meter = lambda *args, **kwargs: ""
dummy_tqdm.format_num = lambda x: str(x)
dummy_tqdm.status_printer = lambda *args, **kwargs: lambda x: None
dummy_tqdm.get_lock = lambda: None
dummy_tqdm.set_lock = lambda x: None
dummy_tqdm.display = lambda *args, **kwargs: None
dummy_tqdm.clear = lambda *args, **kwargs: None
dummy_tqdm.close = lambda *args, **kwargs: None
dummy_tqdm.update = lambda *args, **kwargs: None
dummy_tqdm.refresh = lambda *args, **kwargs: None
dummy_tqdm.disable = True
dummy_tqdm.monitor_interval = 0
dummy_tqdm.monitor = None
dummy_tqdm.pos = 0
# NOTE(review): Python looks up special methods on the type (``function``), not
# the instance, so these two assignments do not make ``dummy_tqdm`` iterable;
# ``iter(dummy_tqdm)`` still raises TypeError.
dummy_tqdm.__iter__ = lambda self: iter([])
dummy_tqdm.__next__ = lambda self: next(iter([]))

# Replace all tqdm variants
# Here ``tqdm`` names the package object: the later ``import tqdm`` rebinds the
# name that ``from tqdm import tqdm`` had bound to the class.
# NOTE(review): the ``tqdm.notebook`` / ``tqdm.auto`` accesses assume those
# submodules were already imported in this process; otherwise they raise
# AttributeError.
tqdm.tqdm = dummy_tqdm
tqdm.std.tqdm = dummy_tqdm
tqdm.notebook.tqdm = dummy_tqdm
tqdm.auto.tqdm = dummy_tqdm
tqdm.gui.tqdm = dummy_tqdm
tqdm.cli.tqdm = dummy_tqdm
# NOTE(review): setting ``__call__`` on a module object does not make the module
# callable (special-method lookup uses the module type).
tqdm.__call__ = dummy_tqdm

# --- Patch missing tqdm.format_sizeof --- #
def _dummy_format_sizeof(num_bytes, *args, **kwargs):
    """
    A fake tqdm.format_sizeof.
    Just returns a simple string of the byte count, enough to trick transformers,
    without affecting your complete disabling of tqdm.
    """
    return f"{num_bytes}"

# If you already have a dummy_tqdm object:
try:
    dummy_tqdm.format_sizeof = _dummy_format_sizeof
except NameError:
    pass  # Skip if dummy_tqdm does not exist

# Also patch the real tqdm module (or the one you replaced)
import sys
if 'tqdm' in sys.modules:
    setattr(sys.modules['tqdm'], 'format_sizeof', _dummy_format_sizeof)

# Set the list of frequency ratio thresholds to test
FREQ_THRESHOLDS = [1.2, 1.4, 1.5, 1.6, 1.8, 2.0]

# 1. Set up the environment: tokenised inputs are moved to this device before
#    the forward pass (assigned once; the original cell assigned it twice).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 3. Load Open LLaMA 13B in bfloat16, letting device_map="auto" place layers
#    across the available GPU/CPU memory.
model_name = "openlm-research/open_llama_13b"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,   # 16-bit BF16 weights
    device_map="auto",            # automatically split layers across GPU/CPU
    low_cpu_mem_usage=True        # reduce peak CPU RAM usage while loading
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Use EOS as the padding token (this overrides any pad token the tokenizer defines).
tokenizer.pad_token = tokenizer.eos_token

# 4. Set up hook functions to collect neuron activations
activations = {}
activated_neurons = {}
ACTIVATION_THRESHOLD = 2

def get_ffn_activation(name):
    """Build a forward hook that records which FFN activations exceed the threshold.

    The returned hook stores, under ``activated_neurons[name]``, a boolean CPU
    tensor marking the elements of the module output that are strictly greater
    than ``ACTIVATION_THRESHOLD``. Only 2-D or 3-D outputs are recorded
    (e.g. [batch, hidden] or [batch, seq_len, hidden]); other ranks are ignored.
    """
    def hook(module, inputs, output):
        rank = output.dim()
        if rank not in (2, 3):
            return
        activated_neurons[name] = (output > ACTIVATION_THRESHOLD).detach().cpu()
    return hook

# 5. Register hooks for the model
# Handles are kept in `hooks` so the hooks can be detached later with hook.remove().
hooks = []

# Register hooks for the SiLU activation function in LlamaMLP
# According to the model structure, the activation function is located at model.layers[i].mlp.act_fn
# Each layer's hook writes its boolean mask to activated_neurons['layer_{i}_mlp_act'].
for i, layer in enumerate(model.model.layers):
    # Register MLP activation function hook - directly using the act_fn attribute
    hook = layer.mlp.act_fn.register_forward_hook(get_ffn_activation(f'layer_{i}_mlp_act'))
    hooks.append(hook)

print(f"Registered hooks for the SiLU activation function in the {len(model.model.layers)}-layer model")


# 7. Collect sample activation data
def process_sample(sample, sample_id):
    """Run one sample through the model and summarize which FFN neurons fired.

    Args:
        sample: Mapping with at least ``'text'`` and ``'label'`` keys.
        sample_id: Identifier copied into the result.

    Returns:
        dict with ``'sample_id'``, ``'label'`` and ``'neural_signature'``. The
        signature maps each hooked layer name to the sorted list of neuron
        indices whose activation exceeded ``ACTIVATION_THRESHOLD`` at *any*
        token position. Layers with no active neuron are omitted.
    """
    # Reset the hook buffers so only this sample's forward pass is recorded
    activations.clear()
    activated_neurons.clear()

    encodings = tokenizer(sample['text'], return_tensors="pt", truncation=True, max_length=256)
    encodings = {k: v.to(device) for k, v in encodings.items()}

    # The forward pass is run for its side effect: the hooks fill activated_neurons.
    with torch.no_grad():
        model(**encodings)

    sample_neural_signature = {}
    for key, value in activated_neurons.items():
        if value.dim() < 2:
            continue
        mask = value.squeeze(0).numpy()  # drop the batch dimension (batch size is 1 here)

        if mask.ndim == 2:  # [seq_len, hidden_dim]
            # A neuron counts as activated if it fired at any position; one
            # vectorized reduction replaces the per-position Python loop.
            fired = mask.any(axis=0)
        elif mask.ndim == 1:  # [hidden_dim]
            fired = mask
        else:
            continue

        active_indices = np.flatnonzero(fired)
        if active_indices.size > 0:
            # Plain Python ints (sorted) rather than numpy scalars in set order
            sample_neural_signature[key] = active_indices.tolist()

    # Return sample information, containing only the label and neural signature
    return {
        'sample_id': sample_id,
        'label': sample['label'],
        'neural_signature': sample_neural_signature
    }

def collect_activations(samples, batch_size=10):
    """Run process_sample over ``samples`` in batches, freeing memory between batches.

    Processing is best-effort: a sample whose processing raises is skipped, so
    its ``sample_id`` is absent from the result. The number of skipped samples
    and the first error are printed instead of being discarded silently.
    Otherwise an all-failing run (e.g. repeated CUDA OOM) would only surface
    later as an empty result.

    Args:
        samples: Sliceable sequence of samples accepted by process_sample.
        batch_size: Number of samples processed between memory clean-ups.

    Returns:
        List of process_sample results, in input order.
    """
    results = []
    failures = 0
    first_error = None  # kept as text so the traceback does not pin tensors in memory

    for start in range(0, len(samples), batch_size):
        batch = samples[start:start + batch_size]
        batch_results = []

        for offset, sample in enumerate(batch):
            try:
                batch_results.append(process_sample(sample, start + offset))
            except Exception as exc:  # best-effort: skip this sample and keep going
                failures += 1
                if first_error is None:
                    first_error = repr(exc)

        results.extend(batch_results)

        # Clean up memory before the next batch (skipped after the last one)
        if start + batch_size < len(samples):
            del batch_results
            gc.collect()
            torch.cuda.empty_cache()

    if failures:
        print(f"Warning: skipped {failures}/{len(samples)} samples during activation collection; "
              f"first error: {first_error}")

    return results

# 8. Analyze neuron activation patterns - modified to accept frequency threshold parameter
def analyze_neuron_activation_patterns(member_neurons, nonmember_neurons, freq_threshold=1.5):
    """Compare per-layer neuron activation frequencies of member vs non-member samples.

    Each sample's ``neural_signature`` maps a layer name to the neurons it
    activated. For every layer seen in any sample, the fraction of samples that
    activate each neuron is computed separately for both groups.

    A neuron is "member dominant" if non-members never activate it, or if its
    member frequency exceeds ``freq_threshold`` times its non-member frequency.
    "Non-member dominant" is the mirror image. Neurons seen in both groups that
    are dominant in neither are "common".

    Returns:
        dict layer name -> dict with keys 'member_dominant', 'nonmember_dominant',
        'common_neurons', 'member_counts', 'nonmember_counts', 'member_freq',
        'nonmember_freq'.
    """
    def _layer_counts(group, layer):
        # Number of samples in `group` that activate each neuron of `layer`
        counts = Counter()
        for entry in group:
            signature = entry['neural_signature']
            if layer in signature:
                counts.update(signature[layer])
        return counts

    def _dominant(own_freq, other_freq):
        # Neurons unseen in the other group, or more than freq_threshold times as frequent
        return {
            neuron: f
            for neuron, f in own_freq.items()
            if neuron not in other_freq or f > other_freq[neuron] * freq_threshold
        }

    layer_names = set()
    for entry in member_neurons + nonmember_neurons:
        layer_names.update(entry['neural_signature'].keys())

    n_member = len(member_neurons)
    n_nonmember = len(nonmember_neurons)

    results = {}
    for layer in layer_names:
        member_counts = _layer_counts(member_neurons, layer)
        nonmember_counts = _layer_counts(nonmember_neurons, layer)

        member_freq = {neuron: c / n_member for neuron, c in member_counts.items()}
        nonmember_freq = {neuron: c / n_nonmember for neuron, c in nonmember_counts.items()}

        member_dominant = _dominant(member_freq, nonmember_freq)
        nonmember_dominant = _dominant(nonmember_freq, member_freq)

        shared = set(member_freq.keys()) & set(nonmember_freq.keys())
        common_neurons = {
            neuron: (member_freq[neuron], nonmember_freq[neuron])
            for neuron in shared
            if neuron not in member_dominant and neuron not in nonmember_dominant
        }

        results[layer] = {
            'member_dominant': member_dominant,
            'nonmember_dominant': nonmember_dominant,
            'common_neurons': common_neurons,
            'member_counts': member_counts,
            'nonmember_counts': nonmember_counts,
            'member_freq': member_freq,
            'nonmember_freq': nonmember_freq
        }

    return results

# 9. Build reference patterns - modified to accept frequency threshold parameter
def build_reference_patterns(train_samples, freq_threshold=1.5, validation=False):
    """Derive per-layer reference activation patterns from labelled training samples.

    Samples with ``label == 1`` are treated as members and ``label == 0`` as
    non-members; any other label is ignored. ``freq_threshold`` is forwarded to
    analyze_neuron_activation_patterns. With ``validation=True`` a short
    per-layer summary is printed.
    """
    members = [entry for entry in train_samples if entry['label'] == 1]
    non_members = [entry for entry in train_samples if entry['label'] == 0]

    print(f"Building reference patterns: {len(members)} member samples, {len(non_members)} non-member samples")

    patterns = analyze_neuron_activation_patterns(members, non_members, freq_threshold)

    if not validation:
        return patterns

    for layer, info in patterns.items():
        print(f"\n{layer} Activation Pattern Analysis:")
        print(f"  Neurons predominantly activated in member samples: {len(info['member_dominant'])}")
        print(f"  Neurons predominantly activated in non-member samples: {len(info['nonmember_dominant'])}")
        print(f"  Neurons activated in both types of samples: {len(info['common_neurons'])}")

    return patterns

# 10. Calculate discrimination score for each layer
def calculate_layer_discrimination_scores(reference_patterns):
    """Score each layer as (#member-dominant neurons - #non-member-dominant neurons).

    A positive score means the layer has more neurons favouring member samples.
    The scores are also printed, highest first.

    Returns:
        dict layer name -> integer score, in the order of ``reference_patterns``.
    """
    layer_scores = {
        name: len(info['member_dominant']) - len(info['nonmember_dominant'])
        for name, info in reference_patterns.items()
    }

    print("\nLayer Discrimination Scores:")
    ranked = sorted(layer_scores.items(), key=lambda kv: kv[1], reverse=True)
    for name, score in ranked:
        print(f"  {name}: {score}")

    return layer_scores

# 11. Select the most discriminative layers
def select_discriminative_layers(layer_scores, top_n=10):
    """Return the names of the ``top_n`` layers with the largest |score|.

    Ties keep the input order. The chosen layers and their scores are printed.
    """
    ranked = sorted(layer_scores.items(), key=lambda kv: abs(kv[1]), reverse=True)
    chosen = ranked[:top_n]

    print("Selected most discriminative layers:")
    for name, score in chosen:
        print(f"  {name}: Score {score}")

    return [name for name, _ in chosen]

# 12. Membership prediction based on relative ratio
def predict_membership_by_relative_ratio(sample, reference_patterns, discriminative_layers, threshold=1.0):
    """Classify one sample as member (1) or non-member (0) from relative overlap ratios.

    For every discriminative layer present both in the sample's
    ``neural_signature`` and in ``reference_patterns``, the sample's active
    neurons are intersected with that layer's member-dominant and
    non-member-dominant sets. Each overlap is divided by the set size (0 for an
    empty set), and the per-layer values are averaged.

    The sample is predicted a member when avg_member / avg_nonmember >= threshold.
    An average non-member ratio of 0 counts as an infinite ratio, so such a
    sample is always predicted a member.

    Returns:
        (prediction, avg_member_ratio, avg_nonmember_ratio). Returns (0, 0, 0)
        when the sample has no activations or no layer could be scored.
    """
    signature = sample['neural_signature']
    if not signature:
        return 0, 0, 0

    scored = 0
    member_sum = 0
    nonmember_sum = 0
    for layer in discriminative_layers:
        if layer not in signature or layer not in reference_patterns:
            continue

        active = set(signature[layer])
        if not active:
            continue

        info = reference_patterns[layer]
        member_set = set(info['member_dominant'])
        nonmember_set = set(info['nonmember_dominant'])

        member_sum += (len(active & member_set) / len(member_set)) if member_set else 0
        nonmember_sum += (len(active & nonmember_set) / len(nonmember_set)) if nonmember_set else 0
        scored += 1

    if scored == 0:
        return 0, 0, 0

    avg_member = member_sum / scored
    avg_nonmember = nonmember_sum / scored
    ratio = avg_member / avg_nonmember if avg_nonmember != 0 else float('inf')
    prediction = int(ratio >= threshold)
    return prediction, avg_member, avg_nonmember

# 14. Find the best threshold on the validation set
def find_best_relative_ratio_threshold(val_samples, reference_patterns, discriminative_layers):
    """Choose the relative-ratio decision threshold that maximizes validation accuracy.

    Each sample is scored with avg_member_ratio / avg_nonmember_ratio from
    predict_membership_by_relative_ratio (``inf`` when the non-member ratio is 0).
    Candidate thresholds are 100 evenly spaced values spanning the finite scores,
    plus the fixed points 0.5, 1, 1.5, 2, 3 and 5. Predictions are
    ``score >= threshold`` for every sample, infinite scores included. The
    first candidate (ascending) with the highest accuracy is chosen.

    Returns:
        (best_threshold, metrics) where metrics has 'accuracy', 'precision',
        'recall' and 'f1' for the chosen threshold.

    Raises:
        ValueError: if ``val_samples`` is empty.
    """
    if not val_samples:
        raise ValueError("val_samples is empty; cannot select a threshold")

    # Score every validation sample
    ratios = []
    labels = []
    for sample in val_samples:
        _, member_ratio, nonmember_ratio = predict_membership_by_relative_ratio(
            sample, reference_patterns, discriminative_layers, threshold=1.0
        )
        ratios.append(float('inf') if nonmember_ratio == 0 else member_ratio / nonmember_ratio)
        labels.append(sample['label'])

    # Spread candidates over the finite scores (inf cannot be a linspace endpoint).
    # If every score is infinite only the fixed points remain, instead of
    # min()/max() failing on an empty sequence.
    finite_ratios = [r for r in ratios if np.isfinite(r)]
    candidate_thresholds = (
        list(np.linspace(min(finite_ratios), max(finite_ratios), 100)) if finite_ratios else []
    )
    candidate_thresholds += [0.5, 1.0, 1.5, 2.0, 3.0, 5.0]
    candidate_thresholds = sorted(set(candidate_thresholds))

    best_threshold = 1.0
    best_accuracy = -1.0  # below any real accuracy, so best_metrics is always filled in
    best_metrics = None

    for threshold in candidate_thresholds:
        predictions = [1 if r >= threshold else 0 for r in ratios]

        accuracy = accuracy_score(labels, predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_threshold = threshold
            best_metrics = {
                'accuracy': accuracy,
                'precision': precision,
                'recall': recall,
                'f1': f1
            }

    print(f"Best relative ratio threshold: {best_threshold:.4f}")
    print(f"Best accuracy: {best_accuracy:.4f}")
    print(f"Precision: {best_metrics['precision']:.4f}")
    print(f"Recall: {best_metrics['recall']:.4f}")
    print(f"F1 Score: {best_metrics['f1']:.4f}")

    return best_threshold, best_metrics

def _auc_scores_from_ratios(ratios):
    """Map relative-ratio scores to finite values usable by ``roc_curve``.

    An ``inf`` ratio (average non-member ratio of 0) is replaced by one shared
    value strictly above the largest finite ratio. Those samples stay in the AUC
    with the top rank that ``ratio >= threshold`` gives them, instead of being
    dropped. Any other non-finite value is mapped the same way.
    """
    finite = [r for r in ratios if np.isfinite(r)]
    ceiling = max(finite) + 1.0 if finite else 1.0
    return [r if np.isfinite(r) else ceiling for r in ratios]


# Simplified function to test different frequency thresholds, returns only AUC values
def run_frequency_threshold_sensitivity_test():
    """Sweep FREQ_THRESHOLDS and report the validation AUC of the relative-ratio attack.

    Uses the global ``ds_dev`` (records with a ``'label'`` key, 1 = member,
    0 = non-member). Each class is split 80/20 into train/validation in dataset
    order, and activations are collected once. For each frequency threshold the
    reference patterns, layer selection and decision threshold are rebuilt. The
    AUC of the member/non-member ratio score is then computed over *all*
    validation samples.

    Returns:
        (threshold_aucs, best_freq_threshold): dict frequency threshold -> AUC,
        and the threshold with the highest AUC.

    Raises:
        Re-raises any exception after removing the registered forward hooks.
    """
    try:
        from random import seed
        seed(42)  # seeds the global `random` module; the split below is deterministic anyway

        # Frequency thresholds to test
        freq_thresholds = FREQ_THRESHOLDS
        print("Starting frequency ratio threshold sensitivity analysis...")
        print(f"Frequency ratio thresholds to be tested: {freq_thresholds}")
        print(f"Fixed activation threshold: {ACTIVATION_THRESHOLD}")

        # Store AUC results for different thresholds
        threshold_aucs = {}

        # Get member and non-member samples from the development set
        dev_member = [item for item in ds_dev if item['label'] == 1]
        dev_nonmember = [item for item in ds_dev if item['label'] == 0]

        # Split each class 80/20 (in dataset order) for training and validation
        train_size_member = int(len(dev_member) * 0.8)
        train_size_nonmember = int(len(dev_nonmember) * 0.8)

        train_member = dev_member[:train_size_member]
        train_nonmember = dev_nonmember[:train_size_nonmember]

        val_member = dev_member[train_size_member:]
        val_nonmember = dev_nonmember[train_size_nonmember:]

        print(f"Training set: {len(train_member)} member samples, {len(train_nonmember)} non-member samples")
        print(f"Validation set: {len(val_member)} member samples, {len(val_nonmember)} non-member samples")

        # Activations depend only on the samples, so collect them once for all thresholds
        print("\nCollecting training sample activations...")
        train_member_acts = collect_activations(train_member)
        train_nonmember_acts = collect_activations(train_nonmember)

        print("\nCollecting validation sample activations...")
        val_member_acts = collect_activations(val_member)
        val_nonmember_acts = collect_activations(val_nonmember)

        val_samples = val_member_acts + val_nonmember_acts

        # Test for each frequency ratio threshold
        for freq_threshold in freq_thresholds:
            print(f"\n===== Testing frequency ratio threshold: {freq_threshold} =====")

            # Build reference activation patterns using the current frequency ratio threshold
            print(f"\nBuilding reference activation patterns (frequency threshold={freq_threshold})...")
            reference_patterns = build_reference_patterns(
                train_member_acts + train_nonmember_acts,
                freq_threshold=freq_threshold,
                validation=True
            )

            # Calculate layer discrimination scores
            print("\nCalculating layer discrimination scores...")
            layer_scores = calculate_layer_discrimination_scores(reference_patterns)

            # Select the most discriminative layers
            print("\nSelecting the most discriminative layers...")
            discriminative_layers = select_discriminative_layers(layer_scores, top_n=10)

            # Find the best relative ratio threshold on the validation set (prints its metrics)
            print("\nFinding the best threshold on the validation set using the relative ratio method...")
            best_threshold, _ = find_best_relative_ratio_threshold(
                val_samples,
                reference_patterns,
                discriminative_layers
            )

            # AUC is threshold-free: score every validation sample by its ratio
            ratios = []
            labels = []
            for sample in val_samples:
                _, member_ratio, nonmember_ratio = predict_membership_by_relative_ratio(
                    sample, reference_patterns, discriminative_layers, threshold=best_threshold
                )
                ratios.append(float('inf') if nonmember_ratio == 0 else member_ratio / nonmember_ratio)
                labels.append(sample['label'])

            # Keep infinite-ratio samples (ranked above every finite score) instead of
            # filtering them out, which computed the AUC on a biased subset.
            fpr, tpr, _ = roc_curve(labels, _auc_scores_from_ratios(ratios))
            roc_auc = auc(fpr, tpr)

            # Record the AUC result for the current frequency threshold
            threshold_aucs[freq_threshold] = roc_auc

            print(f"AUC for frequency threshold {freq_threshold}: {roc_auc:.4f}")

        # Find the best frequency threshold
        best_freq_threshold = max(freq_thresholds, key=lambda t: threshold_aucs[t])

        print("\n===== Frequency Ratio Threshold Sensitivity Analysis Results =====")
        print(f"Best frequency ratio threshold: {best_freq_threshold}")
        print(f"Best AUC: {threshold_aucs[best_freq_threshold]:.4f}")

        # Data in table format
        print("\nAUC Value Table:")
        print("Freq. Threshold | AUC")
        print("----------------|-----")
        for freq in freq_thresholds:
            print(f"{freq:.1f}             | {threshold_aucs[freq]:.4f}")

        return threshold_aucs, best_freq_threshold

    except Exception as e:
        print(f"An error occurred during execution: {str(e)}")
        # Detach the forward hooks so the model is left uninstrumented
        for hook in hooks:
            try:
                hook.remove()
            except Exception:
                pass
        raise  # re-raise with the original traceback


# Run Membership Inference Attack
# Run the frequency-ratio threshold sensitivity analysis for the membership inference attack.
# In a notebook cell __name__ is normally "__main__", so this runs when the cell executes.
if __name__ == "__main__":
    # Run frequency threshold sensitivity test
    threshold_aucs, best_freq_threshold = run_frequency_threshold_sensitivity_test()
    print(f"\nFrequency threshold sensitivity analysis complete! The best frequency ratio threshold is: {best_freq_threshold}")

    # Detach the forward hooks so later model calls are not instrumented
    for hook in hooks:
        hook.remove()

    print("Program execution finished!")

## CCNewsPDD

### Pythia-2.8b

In [None]:
# Import necessary libraries
import os
import torch
import numpy as np
from collections import Counter
from datasets import load_from_disk
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.metrics import roc_curve, auc, accuracy_score, precision_recall_fscore_support
import gc
import warnings
warnings.filterwarnings("ignore")

# Disable tqdm - a safer way
import sys
import importlib
# Completely disable tqdm
import tqdm

# Define a function that does nothing and returns its input
def dummy_tqdm(iterable=None, *args, **kwargs):
    """No-op replacement for ``tqdm``.

    Wrapping an iterable returns it unchanged, so ``for x in tqdm(xs)`` still
    iterates ``xs``. Called without one (e.g. ``tqdm(total=n)``), it returns
    itself, so method calls such as ``.update()`` reach the no-op attributes
    attached to this function.
    """
    if iterable is None:
        return dummy_tqdm
    return iterable

# Add all necessary attributes and methods.
# Because dummy_tqdm() without an iterable returns dummy_tqdm itself, these
# attributes make bar-style usage (e.g. tqdm(total=n).update(1), .close()) a no-op.
dummy_tqdm.format_interval = lambda x: f"{x:.1f}s"
dummy_tqdm.format_meter = lambda *args, **kwargs: ""
dummy_tqdm.format_num = lambda x: str(x)
dummy_tqdm.status_printer = lambda *args, **kwargs: lambda x: None
dummy_tqdm.get_lock = lambda: None
dummy_tqdm.set_lock = lambda x: None
dummy_tqdm.display = lambda *args, **kwargs: None
dummy_tqdm.clear = lambda *args, **kwargs: None
dummy_tqdm.close = lambda *args, **kwargs: None
dummy_tqdm.update = lambda *args, **kwargs: None
dummy_tqdm.refresh = lambda *args, **kwargs: None
dummy_tqdm.disable = True
dummy_tqdm.monitor_interval = 0
dummy_tqdm.monitor = None
dummy_tqdm.pos = 0
# NOTE(review): special methods are looked up on the type, not the instance, so
# these two assignments on a function object likely have no effect on iter()/next().
dummy_tqdm.__iter__ = lambda self: iter([])
dummy_tqdm.__next__ = lambda self: next(iter([]))

# Replace all tqdm variants.
# Only code that looks `tqdm` up after this point sees the no-op; modules that
# already bound the original class keep it.
# NOTE(review): tqdm.notebook / tqdm.auto exist as attributes only once those
# submodules have been imported somewhere -- otherwise these lines raise AttributeError.
tqdm.tqdm = dummy_tqdm
tqdm.std.tqdm = dummy_tqdm
tqdm.notebook.tqdm = dummy_tqdm
tqdm.auto.tqdm = dummy_tqdm
tqdm.gui.tqdm = dummy_tqdm
tqdm.cli.tqdm = dummy_tqdm
# NOTE(review): setting __call__ on a module does not make the module callable; probably a no-op.
tqdm.__call__ = dummy_tqdm

# --- Patch missing tqdm.format_sizeof --- #
def _dummy_format_sizeof(num_bytes, *args, **kwargs):
    """Stand-in for ``tqdm.format_sizeof``.

    Any extra positional/keyword arguments are ignored and ``num_bytes`` is
    returned as its plain default-formatted text (no unit scaling). That is
    enough for callers that only need a string while tqdm stays disabled.
    """
    text = format(num_bytes, "")
    return text

# Expose the fake formatter as dummy_tqdm.format_sizeof; dummy_tqdm is defined
# earlier in this cell, so the NameError guard is only defensive.
try:
    dummy_tqdm.format_sizeof = _dummy_format_sizeof
except NameError:
    pass  # Skip if dummy_tqdm does not exist

# Also set `format_sizeof` on the module object registered as 'tqdm' in
# sys.modules (nothing happens if tqdm was never imported).
import sys
if 'tqdm' in sys.modules:
    setattr(sys.modules['tqdm'], 'format_sizeof', _dummy_format_sizeof)

# 1. Set up the environment
# NOTE: process_sample() in this cell moves inputs to `model.device`, not to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Core Change 1: Switch model to pythia-2.8b ---
model_name = "EleutherAI/pythia-2.8b"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # Specify 16-bit precision
    device_map="auto"           # Automatically manage model distribution on GPU/CPU
)
# --- End of change ---

# Load the corresponding tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Ensure the model is in evaluation mode (disables dropout)
model.eval()

# If the tokenizer does not have a padding token, reuse EOS
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Model loaded: {model_name}, using {model.dtype} precision")
# Reports the device of the first parameter; with device_map="auto" other layers may live elsewhere.
print(f"Model is on device: {next(model.parameters()).device}")

# Display model memory usage
def get_model_size(model):
    """Return the total size of ``model``'s parameters in GiB (bytes / 1024**3).

    Only ``model.parameters()`` are counted; buffers are not included.
    """
    total_bytes = 0
    for param in model.parameters():
        total_bytes += param.numel() * param.element_size()
    return total_bytes / 1024**3

# Parameter memory only, in GiB (1024**3 bytes) despite the "GB" label.
model_size_gb = get_model_size(model)
print(f"Model size: {model_size_gb:.2f} GB")

# 4. Global state shared between the forward hooks and process_sample()
activations = dict()        # layer name -> activation tensor captured by the hook (CPU)
activated_neurons = dict()  # layer name -> boolean mask of activations above the threshold
ACTIVATION_THRESHOLD = 0    # an activation must be strictly greater than this to count

def get_ffn_activation(name):
    """Build a forward hook that records FFN activations for layer ``name``.

    The hook stores the module output (detached, on CPU) in ``activations[name]``.
    For 2-D or 3-D outputs it also stores, in ``activated_neurons[name]``, a
    boolean mask of the elements strictly greater than ``ACTIVATION_THRESHOLD``.

    Activation modules such as ``mlp.act`` normally return a plain tensor, so
    the batch dimension is kept; a tuple/list output is unwrapped defensively.
    Indexing the output with ``[0]`` would select the first batch element
    instead and silently drop the rest of the batch.
    """
    def hook(module, input, output):
        act_output = output[0] if isinstance(output, (tuple, list)) else output

        # Save activation values
        activations[name] = act_output.detach().cpu()

        # Identify activated neurons: [batch, hidden] or [batch, seq_len, hidden]
        if act_output.dim() in (2, 3):
            activated_neurons[name] = (act_output > ACTIVATION_THRESHOLD).detach().cpu()
    return hook

# 5. Register hooks for the model
# Handles are kept in `hooks` so the hooks can be detached later with hook.remove().
hooks = []

# --- Core Change 2: Adapt to Pythia/GPT-NeoX model structure ---
# Analyze the structure of the first layer to understand the location of the activation function
# (diagnostic printout only; nothing below depends on it)
sample_layer = model.gpt_neox.layers[0]
print("Layer Structure Analysis (Pythia/GPT-NeoX):")
for name, module in sample_layer.named_modules():
    print(f"  - {name}: {type(module).__name__}")

# Register hooks for FFN layers - targeting the Pythia/GPT-NeoX model architecture
# Each layer's hook writes to activations / activated_neurons under 'layer_{i}_mlp_act'.
for i, layer in enumerate(model.gpt_neox.layers):
    # The activation function in Pythia/GPT-NeoX is usually at layer.mlp.act
    try:
        module_to_hook = layer.mlp.act
        print(f"Registering hook for mlp.act module ({type(module_to_hook).__name__}) in layer {i}")
        hook = module_to_hook.register_forward_hook(get_ffn_activation(f'layer_{i}_mlp_act'))
        hooks.append(hook)
    except AttributeError:
        # Layers without mlp.act are skipped (reported, not fatal)
        print(f"Warning: mlp.act module not found in layer {i}")
# --- End of change ---

# 7. Collect sample activation data
def process_sample(sample, sample_id):
    """Run one sample through the model and record activated neurons per token position.

    Returns a dict with the sample's id, text, label, the token ids fed to the
    model, and ``activated_neurons``: layer name -> {position: [neuron indices
    above ACTIVATION_THRESHOLD]}. Positions and layers with no active neuron are
    left out, and errors while extracting a layer are ignored (that layer is
    skipped).
    """
    # Reset hook buffers so only this forward pass is recorded
    activations.clear()
    activated_neurons.clear()

    enc = tokenizer(sample['text'], return_tensors="pt", truncation=True, max_length=512)
    enc = {field: tensor.to(model.device) for field, tensor in enc.items()}

    # The forward pass is only needed for its side effect of firing the hooks
    with torch.no_grad():
        model(**enc)

    def _positions(mask):
        # {position: sorted active neuron indices} for a 2-D or 1-D boolean mask
        if mask.ndim == 2:  # [seq_len, hidden_dim]
            found = {}
            for pos, row in enumerate(mask):
                hits = np.where(row)[0]
                if len(hits) > 0:
                    found[pos] = hits.tolist()
            return found
        if mask.ndim == 1:  # [hidden_dim]
            hits = np.where(mask)[0]
            return {0: hits.tolist()} if len(hits) > 0 else {}
        return {}

    per_layer = {}
    for layer_name, mask_tensor in activated_neurons.items():
        try:
            if len(mask_tensor.shape) < 2:
                continue
            found = _positions(mask_tensor.squeeze(0).numpy())  # drop the batch dimension
            if found:
                per_layer[layer_name] = found
        except Exception:
            continue

    return {
        'sample_id': sample_id,
        'text': sample['text'],
        'label': sample['label'],
        'activated_neurons': per_layer,
        'input_ids': enc['input_ids'][0].cpu().numpy(),
    }

def collect_activations(samples, batch_size=10):
    """Run process_sample over ``samples`` in batches, freeing memory between batches.

    Processing is best-effort: a sample whose processing raises is skipped, so
    its ``sample_id`` is absent from the result. The number of skipped samples
    and the first error are printed instead of being discarded silently.
    Otherwise an all-failing run (e.g. repeated CUDA OOM) would only surface
    later as an empty result.

    Args:
        samples: Sliceable sequence of samples accepted by process_sample.
        batch_size: Number of samples processed between memory clean-ups.

    Returns:
        List of process_sample results, in input order.
    """
    results = []
    failures = 0
    first_error = None  # kept as text so the traceback does not pin tensors in memory

    # A plain loop (no tqdm; progress bars are disabled in this notebook)
    for start in range(0, len(samples), batch_size):
        batch = samples[start:start + batch_size]
        batch_results = []

        for offset, sample in enumerate(batch):
            try:
                batch_results.append(process_sample(sample, start + offset))
            except Exception as exc:  # best-effort: skip this sample and keep going
                failures += 1
                if first_error is None:
                    first_error = repr(exc)

        results.extend(batch_results)

        # Clean up memory before the next batch (skipped after the last one)
        if start + batch_size < len(samples):
            del batch_results
            gc.collect()
            torch.cuda.empty_cache()

    if failures:
        print(f"Warning: skipped {failures}/{len(samples)} samples during activation collection; "
              f"first error: {first_error}")

    return results

# 8. Analyze neuron activation patterns
def analyze_neuron_activation_patterns(member_neurons, nonmember_neurons, freq_threshold=1.5):
    """Compare per-layer neuron activation counts of member vs non-member samples.

    ``activated_neurons`` maps layer -> {position: [neuron indices]}. Every
    (position, neuron) occurrence is counted, so a neuron's per-sample
    "frequency" is the mean number of positions at which it fired and can
    exceed 1.

    A neuron is member-dominant if non-members never activate it, or if its
    member frequency exceeds ``freq_threshold`` times its non-member frequency.
    Non-member-dominant is the mirror image. Neurons seen in both groups that
    are dominant in neither are "common". ``freq_threshold`` was previously
    hard-coded to 1.5, which remains the default.

    Returns:
        dict layer name -> dict with keys 'member_dominant', 'nonmember_dominant',
        'common_neurons', 'member_counts', 'nonmember_counts', 'member_freq',
        'nonmember_freq'.
    """
    results = {}
    layer_names = set()
    for sample in member_neurons + nonmember_neurons:
        layer_names.update(sample['activated_neurons'].keys())

    for layer_name in layer_names:
        member_neuron_counts = Counter()
        nonmember_neuron_counts = Counter()
        for sample in member_neurons:
            if layer_name in sample['activated_neurons']:
                for pos_neurons in sample['activated_neurons'][layer_name].values():
                    member_neuron_counts.update(pos_neurons)
        for sample in nonmember_neurons:
            if layer_name in sample['activated_neurons']:
                for pos_neurons in sample['activated_neurons'][layer_name].values():
                    nonmember_neuron_counts.update(pos_neurons)

        member_freq = {n: count / len(member_neurons) for n, count in member_neuron_counts.items()} if member_neurons else {}
        nonmember_freq = {n: count / len(nonmember_neurons) for n, count in nonmember_neuron_counts.items()} if nonmember_neurons else {}

        member_dominant = {
            n: f for n, f in member_freq.items()
            if n not in nonmember_freq or f > nonmember_freq.get(n, 0) * freq_threshold
        }
        nonmember_dominant = {
            n: f for n, f in nonmember_freq.items()
            if n not in member_freq or f > member_freq.get(n, 0) * freq_threshold
        }
        common_neurons = {
            n: (member_freq[n], nonmember_freq[n])
            for n in set(member_freq.keys()) & set(nonmember_freq.keys())
            if n not in member_dominant and n not in nonmember_dominant
        }

        results[layer_name] = {
            'member_dominant': member_dominant, 'nonmember_dominant': nonmember_dominant,
            'common_neurons': common_neurons, 'member_counts': member_neuron_counts,
            'nonmember_counts': nonmember_neuron_counts, 'member_freq': member_freq,
            'nonmember_freq': nonmember_freq
        }
    return results

# 9. Build reference patterns
def build_reference_patterns(train_samples, freq_threshold=1.5, validation=False):
    """Split training samples by label (1 = member, 0 = non-member) and analyze them.

    ``freq_threshold`` is forwarded to analyze_neuron_activation_patterns. With
    ``validation=True`` a per-layer summary of the dominant/common neuron counts
    is printed. Both parameters mirror the Open LLaMA version of this function.
    """
    member_samples = [s for s in train_samples if s['label'] == 1]
    nonmember_samples = [s for s in train_samples if s['label'] == 0]
    reference_patterns = analyze_neuron_activation_patterns(member_samples, nonmember_samples, freq_threshold)

    if validation:
        for layer_name, data in reference_patterns.items():
            print(f"\n{layer_name} Activation Pattern Analysis:")
            print(f"  Neurons predominantly activated in member samples: {len(data['member_dominant'])}")
            print(f"  Neurons predominantly activated in non-member samples: {len(data['nonmember_dominant'])}")
            print(f"  Neurons activated in both types of samples: {len(data['common_neurons'])}")

    return reference_patterns

# 10. Calculate discrimination score for each layer
def calculate_layer_discrimination_scores(reference_patterns):
    """Return layer name -> (#member-dominant neurons - #non-member-dominant neurons)."""
    scores = {}
    for layer_name, info in reference_patterns.items():
        scores[layer_name] = len(info['member_dominant']) - len(info['nonmember_dominant'])
    return scores

# 11. Select the most discriminative layers
def select_discriminative_layers(layer_scores, top_n=10):
    """Return the names of the ``top_n`` layers with the largest absolute score.

    The sort is stable, so layers with equal |score| keep their order from
    ``layer_scores``.
    """
    ranked = sorted(layer_scores, key=lambda layer: -abs(layer_scores[layer]))
    return ranked[:top_n]

# 12. Membership prediction based on relative ratio
def predict_membership_by_relative_ratio(sample, reference_patterns, discriminative_layers, threshold=1.0):
    """Predict whether a single sample is a member using the relative ratio method.

    For each layer in ``discriminative_layers`` that appears both in the sample and
    in ``reference_patterns``, the union of the sample's activated neurons (over all
    positions) is intersected with the layer's member-dominant and
    non-member-dominant sets. Each overlap is divided by the size of its set (0 if
    the set is empty), and these fractions are averaged over the layers used.
    Layers where the sample activated no neuron are skipped.

    Returns:
        ``(prediction, avg_member_ratio, avg_nonmember_ratio)``. ``prediction`` is 1
        when ``avg_member_ratio / avg_nonmember_ratio >= threshold`` (the quotient is
        ``inf`` when the denominator is 0), otherwise 0. ``(0, 0, 0)`` when the sample
        has no activations or no layer could be used.
    """
    activated = sample['activated_neurons']
    if not activated:
        return 0, 0, 0

    member_fractions, nonmember_fractions = [], []
    for layer_name in discriminative_layers:
        if layer_name not in activated or layer_name not in reference_patterns:
            continue
        neurons = {n for positions in activated[layer_name].values() for n in positions}
        if not neurons:
            continue
        patterns = reference_patterns[layer_name]
        member_set = set(patterns['member_dominant'])
        nonmember_set = set(patterns['nonmember_dominant'])
        member_fractions.append(len(neurons & member_set) / len(member_set) if member_set else 0)
        nonmember_fractions.append(len(neurons & nonmember_set) / len(nonmember_set) if nonmember_set else 0)

    n_layers = len(member_fractions)
    if n_layers == 0:
        return 0, 0, 0

    avg_member_ratio = sum(member_fractions) / n_layers
    avg_nonmember_ratio = sum(nonmember_fractions) / n_layers
    ratio = avg_member_ratio / avg_nonmember_ratio if avg_nonmember_ratio != 0 else float('inf')
    return int(ratio >= threshold), avg_member_ratio, avg_nonmember_ratio

# 14. Find the best threshold on the validation set
def find_best_relative_ratio_threshold(val_samples, reference_patterns, discriminative_layers):
    """Choose the score threshold with the highest accuracy on ``val_samples``.

    Each sample's score is ``member_ratio / nonmember_ratio`` from
    ``predict_membership_by_relative_ratio`` (``inf`` when ``nonmember_ratio`` is 0).
    Candidates are 100 evenly spaced values between the smallest and largest finite
    score plus 0.5, 1.0, 1.5, 2.0, 3.0 and 5.0. A sample is predicted as a member
    when its score is >= the threshold.

    Returns:
        ``(best_threshold, metrics)`` with accuracy/precision/recall/f1 at that
        threshold, or ``(1.0, {'accuracy': 0.5, 'precision': 0, 'recall': 0, 'f1': 0})``
        when no sample has a finite score.
    """
    ratios, labels = [], []
    for sample in val_samples:
        _, member_ratio, nonmember_ratio = predict_membership_by_relative_ratio(sample, reference_patterns, discriminative_layers)
        ratios.append(member_ratio / nonmember_ratio if nonmember_ratio != 0 else float('inf'))
        labels.append(sample['label'])

    finite = [r for r in ratios if r != float('inf') and not np.isnan(r)]
    if not finite:
        return 1.0, {'accuracy': 0.5, 'precision': 0, 'recall': 0, 'f1': 0}

    fixed_points = [0.5, 1.0, 1.5, 2.0, 3.0, 5.0]
    candidate_thresholds = sorted(set(list(np.linspace(min(finite), max(finite), 100)) + fixed_points))

    best_threshold, best_accuracy, best_metrics = 1.0, 0, None
    for candidate in candidate_thresholds:
        predictions = [int(r >= candidate) for r in ratios]
        accuracy = accuracy_score(labels, predictions)
        if accuracy <= best_accuracy:
            continue
        best_threshold, best_accuracy = candidate, accuracy
        precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary', zero_division=0)
        best_metrics = {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1}
    return best_threshold, best_metrics

# Calculate AUC and TPR@5%FPR for the test set
def calculate_auc_and_tpr(test_samples, reference_patterns, discriminative_layers, batch_size=10, target_fpr=0.05):
    """Score the test set with the relative-ratio method and compute ROC metrics.

    Activations are collected ``batch_size`` samples at a time. Each sample's score
    is ``member_ratio / nonmember_ratio`` (``inf`` when ``nonmember_ratio`` is 0).
    Samples with an ``inf`` or NaN score are left out of the ROC analysis.

    NOTE(review): the excluded ``inf`` scores include the samples whose overlap
    points most strongly towards "member" (and those with no overlap at all), so
    dropping them may bias the AUC — consider a finite score instead.

    Args:
        test_samples: Sequence of dataset rows with ``'text'`` and ``'label'``.
        reference_patterns: Output of ``build_reference_patterns``.
        discriminative_layers: Layer names used for scoring.
        batch_size: Number of samples processed between cache clean-ups.
        target_fpr: FPR budget for the TPR metric (default 5%).

    Returns:
        ``(roc_auc, tpr_at_fpr)`` where ``tpr_at_fpr`` is the highest TPR reached
        with FPR <= ``target_fpr``; ``(0.5, 0.0)`` when the finite scores do not
        contain both classes.
    """
    all_ratios, all_labels = [], []
    for i in range(0, len(test_samples), batch_size):
        batch = test_samples[i:i + batch_size]
        batch_acts = collect_activations(batch, batch_size=batch_size)
        for sample in batch_acts:
            _, member_ratio, nonmember_ratio = predict_membership_by_relative_ratio(sample, reference_patterns, discriminative_layers)
            ratio = float('inf') if nonmember_ratio == 0 else member_ratio / nonmember_ratio
            all_ratios.append(ratio)
            all_labels.append(sample['label'])
        del batch_acts
        gc.collect()
        torch.cuda.empty_cache()

    filtered_ratios, filtered_labels = [], []
    for ratio, label in zip(all_ratios, all_labels):
        if ratio != float('inf') and not np.isnan(ratio):
            filtered_ratios.append(ratio)
            filtered_labels.append(label)
    if len(set(filtered_labels)) < 2:
        return 0.5, 0.0

    # roc_curve expects (y_true, y_score): labels first, scores second.
    fpr, tpr, _ = roc_curve(filtered_labels, filtered_ratios)
    roc_auc = auc(fpr, tpr)

    # Best TPR whose FPR stays within the budget (never one that exceeds it).
    within_budget = fpr <= target_fpr
    tpr_at_fpr = float(tpr[within_budget].max()) if within_budget.any() else 0.0
    return roc_auc, tpr_at_fpr

# Process all thresholds for a single dataset
def process_dataset(data_type, data_dir="/content/drive/MyDrive/LLM_MIA/data", thresholds=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0)):
    """Evaluate the activation-pattern membership attack on one dataset for each threshold.

    Args:
        data_type: Dataset variant ("bt", "bert" or "prompt"). It selects which
            dev/test datasets are loaded from ``data_dir`` with ``load_from_disk``.
        data_dir: Directory holding the ``pile_cc_mia_*`` datasets.
        thresholds: Activation thresholds to sweep. Each one is written to the
            global ``ACTIVATION_THRESHOLD`` before activations are collected.

    For every threshold, the first half of the dev members and non-members builds
    the reference patterns; the test set's AUC and TPR@5%FPR are printed.
    """
    global ACTIVATION_THRESHOLD
    print(f"\n===== Dataset: {data_type} =====")

    # Load dev and test based on the data_type condition
    if data_type == "prompt":
        dev_path = os.path.join(data_dir, "pile_cc_mia_dev_prompt_gpt4o_simple_v1")
        test_path = os.path.join(data_dir, "pile_cc_mia_test_prompt_gpt4o_simple_v1")
        print(f"Loading specific datasets for 'prompt' type:")
        print(f"  - DEV: {dev_path}")
        print(f"  - TEST: {test_path}")
        ds_dev = load_from_disk(dev_path)
        ds_test = load_from_disk(test_path)
    else:
        ds_dev = load_from_disk(os.path.join(data_dir, f"pile_cc_mia_dev_{data_type}"))
        ds_test = load_from_disk(os.path.join(data_dir, f"pile_cc_mia_test_{data_type}"))

    # None of this depends on the activation threshold, so build it once.
    dev_member = [item for item in ds_dev if item['label'] == 1]
    dev_nonmember = [item for item in ds_dev if item['label'] == 0]
    # Only the first half of each class is used (to build the reference patterns).
    train_member = dev_member[:len(dev_member) // 2]
    train_nonmember = dev_nonmember[:len(dev_nonmember) // 2]
    test_samples = list(ds_test)

    for threshold in thresholds:
        # The FFN hooks read this global when building activation masks.
        ACTIVATION_THRESHOLD = threshold

        train_member_acts = collect_activations(train_member)
        train_nonmember_acts = collect_activations(train_nonmember)
        reference_patterns = build_reference_patterns(train_member_acts + train_nonmember_acts)

        layer_scores = calculate_layer_discrimination_scores(reference_patterns)
        discriminative_layers = select_discriminative_layers(layer_scores, top_n=10)

        test_auc, test_tpr_at_fpr = calculate_auc_and_tpr(test_samples, reference_patterns, discriminative_layers)
        print(f"Dataset: {data_type}, Activation Threshold: {threshold:.1f}, Test Set AUC = {test_auc:.4f}, TPR@5%FPR (%) = {test_tpr_at_fpr * 100:.2f}%")

        del train_member_acts, train_nonmember_acts, reference_patterns, layer_scores, discriminative_layers
        gc.collect()
        torch.cuda.empty_cache()

    del ds_dev, ds_test, dev_member, dev_nonmember, train_member, train_nonmember, test_samples
    gc.collect()
    torch.cuda.empty_cache()

# Main function
if __name__ == "__main__":
    activation_thresholds = [1.0]
    datasets = ["bt", "bert", "prompt"]
    try:
        for dataset in datasets:
            process_dataset(dataset, thresholds=activation_thresholds)
    except Exception as e:
        print(f"An error occurred during execution: {str(e)}")
        raise
    finally:
        # Detach the forward hooks however the run ends (success, error or a
        # KeyboardInterrupt) so they do not stay attached to the model.
        for hook in hooks:
            try:
                hook.remove()
            except Exception:
                pass

### OPT-6.7B

In [None]:
# Import necessary libraries
import os
import torch
import numpy as np
from collections import Counter
from datasets import load_from_disk
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.metrics import roc_curve, auc, accuracy_score, precision_recall_fscore_support
import gc
import warnings
warnings.filterwarnings("ignore")

# Disable tqdm - a safer way
import sys
import importlib
# Completely disable tqdm
import tqdm

# Define a function that does nothing and returns its input.
# Stand-in for the tqdm class: when given an iterable it hands it straight back,
# so wrapped loops still run, just without a progress bar.
def dummy_tqdm(iterable=None, *args, **kwargs):
    """No-op replacement for ``tqdm.tqdm``.

    Returns ``iterable`` unchanged when it is not None; otherwise returns this
    function object itself, so follow-up calls such as ``bar.update()`` or
    ``bar.close()`` resolve to the no-op attributes attached below.

    NOTE(review): a function object is not a context manager, so
    ``with tqdm(total=...) as bar:`` still fails with this stand-in — confirm
    no caller uses that form.
    """
    return iterable if iterable is not None else dummy_tqdm

# Add all necessary attributes and methods
# (names callers may look up on a tqdm class or bar; each is a no-op or a neutral value)
dummy_tqdm.format_interval = lambda x: f"{x:.1f}s"
dummy_tqdm.format_meter = lambda *args, **kwargs: ""
dummy_tqdm.format_num = lambda x: str(x)
dummy_tqdm.status_printer = lambda *args, **kwargs: lambda x: None
dummy_tqdm.get_lock = lambda: None
dummy_tqdm.set_lock = lambda x: None
dummy_tqdm.display = lambda *args, **kwargs: None
dummy_tqdm.clear = lambda *args, **kwargs: None
dummy_tqdm.close = lambda *args, **kwargs: None
dummy_tqdm.update = lambda *args, **kwargs: None
dummy_tqdm.refresh = lambda *args, **kwargs: None
dummy_tqdm.disable = True
dummy_tqdm.monitor_interval = 0
dummy_tqdm.monitor = None
dummy_tqdm.pos = 0
# NOTE(review): the next two assignments do not make dummy_tqdm iterable — Python
# looks up special methods such as __iter__/__next__ on the type, not the instance.
dummy_tqdm.__iter__ = lambda self: iter([])
dummy_tqdm.__next__ = lambda self: next(iter([]))

# Replace all tqdm variants
# `tqdm` here is the module (bound by `import tqdm` above), so these lines rebind the
# `tqdm` attribute of the package and its submodules. Objects that already captured
# the original class (e.g. via an earlier `from tqdm import tqdm`) are not affected.
# NOTE(review): attribute access such as `tqdm.notebook` only works if that submodule
# has already been imported somewhere; otherwise it may raise AttributeError — confirm
# in this runtime.
tqdm.tqdm = dummy_tqdm
tqdm.std.tqdm = dummy_tqdm
tqdm.notebook.tqdm = dummy_tqdm
tqdm.auto.tqdm = dummy_tqdm
tqdm.gui.tqdm = dummy_tqdm
tqdm.cli.tqdm = dummy_tqdm
# NOTE(review): setting __call__ on a module object does not make the module callable.
tqdm.__call__ = dummy_tqdm

# --- Patch missing tqdm.format_sizeof --- #
def _dummy_format_sizeof(num_bytes, *args, **kwargs):
    """
    Stand-in for tqdm's ``format_sizeof`` helper.

    Returns ``num_bytes`` formatted as a plain string (no unit suffix) and ignores
    any extra arguments. Presumably added because library code (the original note
    mentions transformers) looked up ``format_sizeof`` on the patched objects —
    TODO confirm which caller needs it.
    """
    return f"{num_bytes}"

# Attach it to dummy_tqdm (defined earlier in this cell); the NameError guard only
# matters if this snippet is run without that definition.
try:
    dummy_tqdm.format_sizeof = _dummy_format_sizeof
except NameError:
    pass  # Skip if dummy_tqdm does not exist

# Also expose it as a module-level attribute of the loaded `tqdm` package.
import sys
if 'tqdm' in sys.modules:
    setattr(sys.modules['tqdm'], 'format_sizeof', _dummy_format_sizeof)

# 1. Set up the environment
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load OPT-6.7B in half precision (FP16); device_map="auto" automatically manages
# how the model is distributed over GPU/CPU.
model_name = "facebook/opt-6.7b"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Inference only: evaluation mode.
model.eval()

# Fall back to the EOS token for padding when the tokenizer has no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Model loaded, using {model.dtype} precision")
print(f"Model is on device: {next(iter(model.parameters())).device}")

# Display model memory usage
def get_model_size(model):
    """Return the total size of ``model``'s parameters in GiB (bytes / 1024**3).

    Only ``model.parameters()`` are counted; buffers are not included.
    """
    total_bytes = 0
    for param in model.parameters():
        total_bytes += param.numel() * param.element_size()
    return total_bytes / 1024**3

model_size_gb = get_model_size(model)
print(f"Model size: {model_size_gb:.2f} GB")
# The pad-token fallback is applied once, right after the tokenizer is loaded.

# 4. Set up hook functions to collect neuron activations
activations = {}          # hook name -> detached CPU copy of the latest module output
activated_neurons = {}    # hook name -> boolean mask of output > ACTIVATION_THRESHOLD
ACTIVATION_THRESHOLD = 0  # Set activation threshold (read by the hook at call time)


def get_ffn_activation(name):
    """Build a forward hook that records FFN activations under ``name``.

    The hook stores a detached CPU copy of the module output in
    ``activations[name]``. For 2-D or 3-D outputs it also stores a boolean mask of
    the entries strictly greater than the current global ``ACTIVATION_THRESHOLD``
    in ``activated_neurons[name]``; outputs of any other rank get no mask.
    """
    def hook(module, input, output):
        activations[name] = output.detach().cpu()
        # Assumed layouts: [batch_size, seq_len, hidden_dim] or [batch_size, hidden_dim];
        # both are masked the same way.
        if output.dim() in (2, 3):
            activated_neurons[name] = (output > ACTIVATION_THRESHOLD).detach().cpu()
    return hook

# 5. Register hooks for the model
hooks = []

# Print the submodules of the first decoder layer so the activation module's name is visible
sample_layer = model.model.decoder.layers[0]
print("Layer Structure Analysis:")
for name, module in sample_layer.named_modules():
    print(f"  - {name}: {type(module).__name__}")

# Attach a forward hook to every GELU/ReLU module inside each OPT decoder layer
for i, layer in enumerate(model.model.decoder.layers):
    activation_found = False
    for name, module in layer.named_modules():
        if not isinstance(module, (torch.nn.GELU, torch.nn.ReLU)):
            continue
        print(f"Registering hook for module {name} in layer {i}")
        hook = module.register_forward_hook(get_ffn_activation(f'layer_{i}_{name}'))
        hooks.append(hook)
        activation_found = True

    if not activation_found:
        print(f"Warning: Activation function not found in layer {i}")

# 7. Collect sample activation data
def process_sample(sample, sample_id):
    """Run one sample through the model and record which hooked neurons fired.

    The forward hooks fill the module-level ``activations`` / ``activated_neurons``
    dicts during the forward pass; both are cleared first.

    Args:
        sample: Mapping with ``'text'`` and ``'label'`` entries.
        sample_id: Identifier copied into the returned record.

    Returns:
        dict with ``sample_id``, ``text``, ``label``, ``input_ids`` (numpy array of
        token ids after truncation to 512 tokens) and ``activated_neurons``, which
        maps each hook name to ``{position: [indices of active neurons]}``.
        Positions with no active neuron are omitted; a 1-D mask is stored under
        position 0.
    """
    activations.clear()
    activated_neurons.clear()

    encodings = tokenizer(sample['text'], return_tensors="pt", truncation=True, max_length=512)
    encodings = {k: v.to(model.device) for k, v in encodings.items()}

    # Only the hooks' side effects are needed; the model output is discarded right away.
    with torch.no_grad():
        model(**encodings)

    sample_activated_neurons = {}
    for key, value in activated_neurons.items():
        if len(value.shape) < 2:
            continue
        try:
            # Remove the batch dimension (one sample per forward pass).
            mask = value.squeeze(0).numpy()
        except Exception as e:
            # Best effort: skip this layer for this sample, but make the skip visible
            # (warnings are filtered out in this cell, so print instead).
            print(f"Skipping activations of {key} for sample {sample_id}: {e}")
            continue

        if mask.ndim == 2:  # [seq_len, hidden_dim]
            position_neurons = {}
            for pos, row in enumerate(mask):
                active_indices = np.flatnonzero(row)
                if active_indices.size > 0:
                    position_neurons[pos] = active_indices.tolist()
            sample_activated_neurons[key] = position_neurons
        elif mask.ndim == 1:  # [hidden_dim]
            active_indices = np.flatnonzero(mask)
            if active_indices.size > 0:
                sample_activated_neurons[key] = {0: active_indices.tolist()}

    return {
        'sample_id': sample_id,
        'text': sample['text'],
        'label': sample['label'],
        'activated_neurons': sample_activated_neurons,
        'input_ids': encodings['input_ids'][0].cpu().numpy(),
    }

def collect_activations(samples, batch_size=10):
    """Collect activation data for samples in batches.

    Each sample is passed to ``process_sample`` with its index in ``samples`` as
    ``sample_id``. A sample whose processing raises is skipped (best effort), but the
    skip is printed so a shorter result list does not go unnoticed. Python garbage
    and the CUDA cache are released between batches, except after the last one.

    Returns:
        List of per-sample records in input order, without the skipped samples.
    """
    results = []
    for start in range(0, len(samples), batch_size):
        for offset, sample in enumerate(samples[start:start + batch_size]):
            sample_id = start + offset
            try:
                results.append(process_sample(sample, sample_id))
            except Exception as e:
                print(f"Skipping sample {sample_id}: {type(e).__name__}: {e}")

        if start + batch_size < len(samples):  # Not the last batch
            gc.collect()
            torch.cuda.empty_cache()

    return results

# 8. Analyze neuron activation patterns
def analyze_neuron_activation_patterns(member_neurons, nonmember_neurons):
    """Compare per-layer neuron activation counts between member and non-member samples.

    For each layer seen in any sample, activated-neuron indices are counted over all
    positions of all samples in each group. Each count is divided by the number of
    samples in the group to give a "frequency", which can exceed 1 because a neuron
    counts once per active position. A neuron is member-dominant when it never fires
    for non-members or its member frequency is more than 1.5x its non-member
    frequency, and symmetrically for non-member-dominant. Neurons seen in both
    groups that are neither are "common".

    Returns:
        ``{layer_name: {'member_dominant', 'nonmember_dominant', 'common_neurons',
        'member_counts', 'nonmember_counts', 'member_freq', 'nonmember_freq'}}``.
    """
    def count_activations(samples, layer_name):
        counts = Counter()
        for record in samples:
            for pos_neurons in record['activated_neurons'].get(layer_name, {}).values():
                counts.update(pos_neurons)
        return counts

    def to_freq(counts, n_samples):
        return {neuron: c / n_samples for neuron, c in counts.items()} if n_samples else {}

    layer_names = set()
    for record in member_neurons + nonmember_neurons:
        layer_names.update(record['activated_neurons'].keys())

    results = {}
    for layer_name in layer_names:
        member_counts = count_activations(member_neurons, layer_name)
        nonmember_counts = count_activations(nonmember_neurons, layer_name)
        member_freq = to_freq(member_counts, len(member_neurons))
        nonmember_freq = to_freq(nonmember_counts, len(nonmember_neurons))

        member_dominant = {
            neuron: f for neuron, f in member_freq.items()
            if neuron not in nonmember_freq or f > nonmember_freq[neuron] * 1.5
        }
        nonmember_dominant = {
            neuron: f for neuron, f in nonmember_freq.items()
            if neuron not in member_freq or f > member_freq[neuron] * 1.5
        }
        common_neurons = {
            neuron: (member_freq[neuron], nonmember_freq[neuron])
            for neuron in set(member_freq.keys()) & set(nonmember_freq.keys())
            if neuron not in member_dominant and neuron not in nonmember_dominant
        }

        results[layer_name] = {
            'member_dominant': member_dominant,
            'nonmember_dominant': nonmember_dominant,
            'common_neurons': common_neurons,
            'member_counts': member_counts,
            'nonmember_counts': nonmember_counts,
            'member_freq': member_freq,
            'nonmember_freq': nonmember_freq,
        }

    return results

# 9. Build reference patterns
def build_reference_patterns(train_samples):
    """Build reference activation patterns using training samples.

    ``label == 1`` marks members and ``label == 0`` non-members (other labels are
    ignored). Returns ``analyze_neuron_activation_patterns`` applied to the two groups.
    """
    member_samples = list(filter(lambda s: s['label'] == 1, train_samples))
    nonmember_samples = list(filter(lambda s: s['label'] == 0, train_samples))
    return analyze_neuron_activation_patterns(member_samples, nonmember_samples)

# 10. Calculate discrimination score for each layer
def calculate_layer_discrimination_scores(reference_patterns):
    """Calculate the discriminative power score for each layer.

    Score = number of member-dominant neurons minus number of non-member-dominant
    neurons; a score above 0 means the layer leans towards identifying members.
    Layer order follows ``reference_patterns``.
    """
    return {
        name: len(layer['member_dominant']) - len(layer['nonmember_dominant'])
        for name, layer in reference_patterns.items()
    }

# 11. Select the most discriminative layers
def select_discriminative_layers(layer_scores, top_n=10):
    """Return the ``top_n`` layer names ranked by absolute discrimination score.

    Ties keep their order from ``layer_scores`` (the sort is stable).
    """
    ranked = sorted(layer_scores.items(), key=lambda kv: -abs(kv[1]))
    return [name for name, _ in ranked][:top_n]

# 12. Membership prediction based on relative ratio
def predict_membership_by_relative_ratio(sample, reference_patterns, discriminative_layers, threshold=1.0):
    """Predict whether a single sample is a member using the relative ratio method.

    For every usable layer (listed in ``discriminative_layers`` and present in both
    the sample and ``reference_patterns``, with at least one activated neuron), the
    fraction of the layer's member-dominant and non-member-dominant neurons that the
    sample activates (at any position) is computed; an empty dominant set gives 0.
    The fractions are averaged over the usable layers.

    Returns:
        ``(prediction, avg_member_ratio, avg_nonmember_ratio)`` where ``prediction``
        is 1 if ``avg_member_ratio / avg_nonmember_ratio >= threshold`` (``inf`` when
        the denominator is 0), else 0. ``(0, 0, 0)`` if the sample has no activation
        data or no usable layer (defaults to non-member).
    """
    if not sample['activated_neurons']:
        return 0, 0, 0

    def overlap_fraction(neurons, dominant):
        dominant_set = set(dominant.keys())
        return len(neurons.intersection(dominant_set)) / len(dominant_set) if dominant_set else 0

    member_total = 0
    nonmember_total = 0
    n_used = 0
    for layer_name in discriminative_layers:
        if layer_name not in sample['activated_neurons'] or layer_name not in reference_patterns:
            continue
        neurons = set()
        for pos_neurons in sample['activated_neurons'][layer_name].values():
            neurons.update(pos_neurons)
        if not neurons:
            continue
        layer = reference_patterns[layer_name]
        member_total += overlap_fraction(neurons, layer['member_dominant'])
        nonmember_total += overlap_fraction(neurons, layer['nonmember_dominant'])
        n_used += 1

    if not n_used:
        return 0, 0, 0

    avg_member_ratio = member_total / n_used
    avg_nonmember_ratio = nonmember_total / n_used
    ratio = float('inf') if avg_nonmember_ratio == 0 else avg_member_ratio / avg_nonmember_ratio
    prediction = 1 if ratio >= threshold else 0
    return prediction, avg_member_ratio, avg_nonmember_ratio

# 14. Find the best threshold on the validation set
def find_best_relative_ratio_threshold(val_samples, reference_patterns, discriminative_layers):
    """Find the best threshold on the validation set using the relative ratio method.

    Each sample's score is ``member_ratio / nonmember_ratio`` from
    ``predict_membership_by_relative_ratio`` (``inf`` when ``nonmember_ratio`` is 0).
    Candidate thresholds are 100 evenly spaced values between the smallest and
    largest finite score plus 0.5, 1.0, 1.5, 2.0, 3.0 and 5.0. A sample is predicted
    as a member when its score is >= the threshold, so ``inf`` scores always count as
    members.

    Returns:
        ``(best_threshold, metrics)`` with accuracy/precision/recall/f1 at the most
        accurate threshold, or ``(1.0, {'accuracy': 0.5, 'precision': 0, 'recall': 0,
        'f1': 0})`` if no sample has a finite score.
    """
    ratios, labels = [], []
    for sample in val_samples:
        _, member_ratio, nonmember_ratio = predict_membership_by_relative_ratio(
            sample, reference_patterns, discriminative_layers, threshold=1.0
        )
        ratios.append(float('inf') if nonmember_ratio == 0 else member_ratio / nonmember_ratio)
        labels.append(sample['label'])

    # Only finite scores define the candidate range.
    finite_ratios = [r for r in ratios if r != float('inf') and not np.isnan(r)]
    if not finite_ratios:
        return 1.0, {'accuracy': 0.5, 'precision': 0, 'recall': 0, 'f1': 0}

    candidate_thresholds = list(np.linspace(min(finite_ratios), max(finite_ratios), 100))
    candidate_thresholds += [0.5, 1.0, 1.5, 2.0, 3.0, 5.0]
    candidate_thresholds = sorted(set(candidate_thresholds))

    best_threshold, best_accuracy, best_metrics = 1.0, 0, None
    for threshold in candidate_thresholds:
        predictions = [1 if r >= threshold else 0 for r in ratios]
        accuracy = accuracy_score(labels, predictions)
        if accuracy > best_accuracy:
            # Precision/recall/F1 are only needed for the threshold that is kept.
            precision, recall, f1, _ = precision_recall_fscore_support(
                labels, predictions, average='binary', zero_division=0
            )
            best_threshold, best_accuracy = threshold, accuracy
            best_metrics = {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1}

    return best_threshold, best_metrics

# Calculate AUC and TPR@5%FPR for the test set
def calculate_auc_and_tpr(test_samples, reference_patterns, discriminative_layers, batch_size=10, target_fpr=0.05):
    """Calculate the AUC value and TPR@5%FPR for the test set.

    Activations are collected ``batch_size`` samples at a time. Each sample's score
    is ``member_ratio / nonmember_ratio`` (``inf`` when ``nonmember_ratio`` is 0).
    Samples with an ``inf`` or NaN score are left out of the ROC analysis.

    NOTE(review): the excluded ``inf`` scores include the samples whose overlap points
    most strongly towards "member" (and those with no overlap at all), so dropping
    them may bias the AUC — consider a finite score instead.

    Args:
        test_samples: Sequence of dataset rows with ``'text'`` and ``'label'``.
        reference_patterns: Output of ``build_reference_patterns``.
        discriminative_layers: Layer names used for scoring.
        batch_size: Number of samples processed between cache clean-ups.
        target_fpr: FPR budget for the TPR metric (default 5%).

    Returns:
        ``(roc_auc, tpr_at_fpr)`` where ``tpr_at_fpr`` is the highest TPR reached with
        FPR <= ``target_fpr``; ``(0.5, 0.0)`` when the finite scores do not contain
        both classes.
    """
    all_ratios, all_labels = [], []
    for i in range(0, len(test_samples), batch_size):
        batch = test_samples[i:i + batch_size]
        batch_acts = collect_activations(batch, batch_size=batch_size)
        for sample in batch_acts:
            _, member_ratio, nonmember_ratio = predict_membership_by_relative_ratio(
                sample, reference_patterns, discriminative_layers
            )
            all_ratios.append(float('inf') if nonmember_ratio == 0 else member_ratio / nonmember_ratio)
            all_labels.append(sample['label'])
        # Release memory
        del batch_acts
        gc.collect()
        torch.cuda.empty_cache()

    # Filter out infinite values for ROC analysis
    filtered_ratios, filtered_labels = [], []
    for ratio, label in zip(all_ratios, all_labels):
        if ratio != float('inf') and not np.isnan(ratio):
            filtered_ratios.append(ratio)
            filtered_labels.append(label)

    if len(set(filtered_labels)) < 2:
        return 0.5, 0.0  # Cannot calculate metrics

    fpr, tpr, _ = roc_curve(filtered_labels, filtered_ratios)
    roc_auc = auc(fpr, tpr)

    # Best TPR whose FPR stays within the budget (never a point that exceeds it).
    within_budget = fpr <= target_fpr
    tpr_at_fpr = float(tpr[within_budget].max()) if within_budget.any() else 0.0
    return roc_auc, tpr_at_fpr

# Process all thresholds for a single dataset
def process_dataset(data_type, data_dir="/content/drive/MyDrive/LLM_MIA/data", thresholds=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0)):
    """Evaluate the activation-pattern membership attack on one dataset for each threshold.

    Args:
        data_type: Dataset variant ("bt", "bert" or "prompt"). It selects which
            dev/test datasets are loaded from ``data_dir`` with ``load_from_disk``.
        data_dir: Directory holding the ``pile_cc_mia_*`` datasets.
        thresholds: Activation thresholds to sweep. Each one is written to the
            global ``ACTIVATION_THRESHOLD`` before activations are collected.

    For every threshold, the first half of the dev members and non-members builds
    the reference patterns, the most discriminative layers are selected, and the
    test set's AUC and TPR@5%FPR are printed.
    """
    global ACTIVATION_THRESHOLD

    print(f"\n===== Dataset: {data_type} =====")

    # "prompt" uses differently named dev/test directories; bt and bert follow
    # the pile_cc_mia_{dev,test}_{data_type} pattern.
    if data_type == "prompt":
        dev_path = os.path.join(data_dir, "pile_cc_mia_dev_prompt_gpt4o_simple_v1")
        test_path = os.path.join(data_dir, "pile_cc_mia_test_prompt_gpt4o_simple_v1")
        print(f"Loading specific datasets for 'prompt' type:")
        print(f"  - DEV: {dev_path}")
        print(f"  - TEST: {test_path}")
        ds_dev = load_from_disk(dev_path)
        ds_test = load_from_disk(test_path)
    else:
        ds_dev = load_from_disk(os.path.join(data_dir, f"pile_cc_mia_dev_{data_type}"))
        ds_test = load_from_disk(os.path.join(data_dir, f"pile_cc_mia_test_{data_type}"))

    # None of this depends on the activation threshold, so build it once.
    dev_member = [item for item in ds_dev if item['label'] == 1]
    dev_nonmember = [item for item in ds_dev if item['label'] == 0]
    # Only the first half of each class is used (to build the reference patterns);
    # the reported metrics come from the test set.
    train_member = dev_member[:len(dev_member) // 2]
    train_nonmember = dev_nonmember[:len(dev_nonmember) // 2]
    test_samples = list(ds_test)

    for threshold in thresholds:
        # The FFN hooks read this global when building activation masks.
        ACTIVATION_THRESHOLD = threshold

        train_member_acts = collect_activations(train_member)
        train_nonmember_acts = collect_activations(train_nonmember)
        reference_patterns = build_reference_patterns(train_member_acts + train_nonmember_acts)

        layer_scores = calculate_layer_discrimination_scores(reference_patterns)
        discriminative_layers = select_discriminative_layers(layer_scores, top_n=10)

        test_auc, test_tpr_at_fpr = calculate_auc_and_tpr(test_samples, reference_patterns, discriminative_layers)
        print(f"Dataset: {data_type}, Threshold: {threshold:.1f}, Test Set AUC = {test_auc:.4f}, TPR@5%FPR (%) = {test_tpr_at_fpr * 100:.2f}%")

        del train_member_acts, train_nonmember_acts
        del reference_patterns, layer_scores, discriminative_layers
        gc.collect()
        torch.cuda.empty_cache()

    # Release memory after processing the current dataset
    del ds_dev, ds_test, dev_member, dev_nonmember, train_member, train_nonmember, test_samples
    gc.collect()
    torch.cuda.empty_cache()

# Main function
if __name__ == "__main__":
    # Set the activation thresholds to be evaluated
    activation_thresholds = [1.0]

    # Process the three datasets sequentially
    datasets = ["bt", "bert", "prompt"]

    try:
        for dataset in datasets:
            process_dataset(dataset, thresholds=activation_thresholds)
    except Exception as e:
        print(f"An error occurred during execution: {str(e)}")
        raise
    finally:
        # Remove the hooks however the run ends (success, error or a
        # KeyboardInterrupt) so they do not stay attached to the model.
        for hook in hooks:
            try:
                hook.remove()
            except Exception:
                pass

## WikiMIA

### Pythia-2.8b

In [None]:
# Import necessary libraries
import os
import torch
import numpy as np
from collections import Counter
from datasets import load_from_disk, load_dataset, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.metrics import roc_curve, auc, accuracy_score, precision_recall_fscore_support
import gc
import warnings
warnings.filterwarnings("ignore")

# Disable tqdm - a safer way
import sys
import importlib
import tqdm as tqdm_module

# Keep a reference to whatever is registered as the `tqdm` module right now. tqdm has
# already been imported a few lines above, so this is the real module (None only if
# the key were missing from sys.modules).
# NOTE(review): `original_module` is not used in this cell; presumably it was kept so
# the monkey-patch below could be undone by hand — confirm or remove.
original_module = sys.modules.get('tqdm', None)

# Create a no-op tqdm substitute
class DummyTqdmModule:
    """No-op stand-in for a ``tqdm`` progress-bar instance.

    Accepts and ignores any constructor arguments. Exposes the commonly used
    progress-bar methods as no-ops, iterates over nothing, and supports the
    context-manager protocol so ``with tqdm(total=n) as bar: bar.update(1)``
    works while tqdm is disabled.
    """

    def __init__(self, *args, **kwargs):
        pass

    def update(self, *args, **kwargs):
        pass

    def close(self, *args, **kwargs):
        pass

    # Further progress-bar methods that callers frequently invoke; all no-ops.
    def set_description(self, *args, **kwargs):
        pass

    def set_postfix(self, *args, **kwargs):
        pass

    def refresh(self, *args, **kwargs):
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Returning False lets an exception raised inside the `with` body propagate.
        return False

    def __iter__(self):
        return self

    def __next__(self):
        raise StopIteration

# Override all methods of the tqdm class
def dummy_tqdm(*args, **kwargs):
    """Drop-in replacement for ``tqdm.tqdm`` that never draws a progress bar.

    If an iterable is supplied (first positional argument, or ``iterable=``),
    it is returned unchanged so the caller's loop visits every item. Any
    iterable is passed through, not only lists: returning an empty dummy for a
    ``range``, generator or dataset would silently skip the caller's loop.
    Without an iterable, a no-op ``DummyTqdmModule`` bar is returned (for
    ``tqdm(total=n)`` style use).
    """
    iterable = args[0] if args else kwargs.get('iterable')
    if iterable is not None:
        return iterable
    return DummyTqdmModule()

# Patch all possible tqdm versions
# Only submodules that are already imported are patched; absent ones are skipped
# rather than force-imported.
for name in ['tqdm', 'tqdm.std', 'tqdm.auto', 'tqdm.notebook', 'tqdm.rich', 'tqdm.cli', 'tqdm.gui', 'tqdm.keras']:
    module = sys.modules.get(name)
    if module is None:
        continue
    try:
        module.tqdm = dummy_tqdm
    except (AttributeError, TypeError):
        # Best effort: a module object that rejects attribute assignment stays unpatched.
        pass

tqdm_module.tqdm = dummy_tqdm

# --- Environment: use the GPU when one is visible, otherwise the CPU ---
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# --- Model and tokenizer (Pythia-2.8B, loaded in its default dtype) ---
model_name = "EleutherAI/pythia-2.8b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
model.eval()  # inference mode (e.g. dropout disabled)

# Reuse the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# --- Shared state written by the forward hooks registered below ---
activations = {}          # hook name -> detached CPU copy of the hooked module's output
activated_neurons = {}    # hook name -> boolean mask of outputs above ACTIVATION_THRESHOLD
ACTIVATION_THRESHOLD = 0  # value an activation must exceed for its neuron to count as active

def get_ffn_activation(name):
    """Build a forward hook that records an FFN activation under ``name``.

    On every forward pass the hook stores a detached CPU copy of the module
    output in ``activations[name]``. For 2-D or 3-D outputs it also stores the
    boolean mask ``output > ACTIVATION_THRESHOLD`` (on CPU) in
    ``activated_neurons[name]``. Outputs of any other rank leave
    ``activated_neurons`` untouched.
    """
    def hook(module, input, output):
        activations[name] = output.detach().cpu()

        # Only [batch, hidden] and [batch, seq, hidden] style outputs get a mask.
        if output.dim() not in (2, 3):
            return
        # The global threshold is read at call time, so changing it affects later passes.
        activated_neurons[name] = (output > ACTIVATION_THRESHOLD).detach().cpu()
    return hook

# 5. Register hooks for the model
hooks = []

# One hook per transformer block, attached to that block's FFN activation module.
for i, layer in enumerate(model.gpt_neox.layers):
    if not hasattr(layer.mlp, 'act'):
        # No `act` attribute: fall back to the first GELU/ReLU submodule inside the MLP.
        for name, module in layer.mlp.named_modules():
            if isinstance(module, (torch.nn.GELU, torch.nn.ReLU)):
                hook = module.register_forward_hook(get_ffn_activation(f'layer_{i}_ffn_act_{name}'))
                hooks.append(hook)
                break
        continue

    # The MLP exposes its activation module directly as `act`.
    hook = layer.mlp.act.register_forward_hook(get_ffn_activation(f'layer_{i}_ffn_act'))
    hooks.append(hook)

# 7. Collect sample activation data
def process_sample(sample, sample_id):
    """Run one sample through the model and summarise which neurons fired.

    Args:
        sample: mapping with at least ``'text'`` and ``'label'`` keys.
        sample_id: identifier copied into the result.

    Returns:
        dict with ``sample_id``, ``text``, ``label``, ``activated_neurons``
        (hook name -> {token position: [indices of neurons above threshold]})
        and ``input_ids`` (the tokenised input as a NumPy array).
    """
    # Reset the hook buffers so only this sample's forward pass is recorded.
    activations.clear()
    activated_neurons.clear()

    encoded = tokenizer(sample['text'], return_tensors="pt", truncation=True, max_length=512)
    encoded = {key: tensor.to(device) for key, tensor in encoded.items()}

    with torch.no_grad():
        model(**encoded)  # forward pass only; the hooks fill `activated_neurons`

    neurons_by_layer = {}
    for layer_key, mask_tensor in activated_neurons.items():
        try:
            if len(mask_tensor.shape) < 2:
                continue
            mask = mask_tensor.squeeze(0).numpy()  # drop the batch dimension

            if len(mask.shape) == 2:  # [seq_len, hidden_dim]
                per_position = {}
                for position, row in enumerate(mask):
                    hits = np.where(row)[0]
                    if hits.size:
                        per_position[position] = hits.tolist()
                neurons_by_layer[layer_key] = per_position
            elif len(mask.shape) == 1:  # [hidden_dim]
                hits = np.where(mask)[0]
                if hits.size:
                    neurons_by_layer[layer_key] = {0: hits.tolist()}
        except Exception:
            # Best effort: a layer whose mask cannot be converted is left out.
            pass

    return {
        'sample_id': sample_id,
        'text': sample['text'],
        'label': sample['label'],
        'activated_neurons': neurons_by_layer,
        'input_ids': encoded['input_ids'][0].cpu().numpy(),
    }

def collect_activations(samples, batch_size=10):
    """Run ``process_sample`` over ``samples`` in chunks and gather the results.

    Args:
        samples: sliceable sequence of samples (each with ``'text'`` and ``'label'``).
        batch_size: number of samples processed between memory clean-ups.

    Returns:
        list of per-sample result dicts, in input order. A sample whose processing
        raises is skipped. Each skip is printed rather than hidden, because a
        missing sample silently changes the member/non-member balance of the
        evaluation.
    """
    results = []
    skipped = 0

    for i in range(0, len(samples), batch_size):
        batch = samples[i:i + batch_size]
        batch_results = []

        for j, sample in enumerate(batch):
            try:
                batch_results.append(process_sample(sample, i + j))
            except Exception as err:  # keep the best-effort behaviour, but make it visible
                skipped += 1
                print(f"Skipping sample {i + j}: {type(err).__name__}: {err}")

        results.extend(batch_results)

        # Free memory between batches (unnecessary after the last one).
        if i + batch_size < len(samples):
            del batch_results
            gc.collect()
            torch.cuda.empty_cache()

    if skipped:
        print(f"collect_activations: skipped {skipped} of {len(samples)} samples")
    return results

# 8. Analyze neuron activation patterns
def analyze_neuron_activation_patterns(member_neurons, nonmember_neurons, dominance_ratio=1.5):
    """Compare per-neuron activation counts between member and non-member samples.

    Args:
        member_neurons: list of per-sample result dicts (each with an
            ``'activated_neurons'`` mapping layer -> {position: [neuron ids]})
            for member samples.
        nonmember_neurons: the same for non-member samples.
        dominance_ratio: a neuron is "dominant" for a group when it never fires
            in the other group, or when its frequency in this group exceeds
            ``dominance_ratio`` times its frequency in the other group.

    Returns:
        dict mapping each layer name to a dict with ``member_dominant``,
        ``nonmember_dominant``, ``common_neurons``, ``member_counts``,
        ``nonmember_counts``, ``member_freq`` and ``nonmember_freq``.

    Counts are accumulated over token positions, so a "frequency" is the average
    number of positions per sample at which the neuron fired and may exceed 1.
    """
    def count_layer(samples, layer_name):
        # Number of (sample, position) hits per neuron id in `layer_name`.
        counts = Counter()
        for sample in samples:
            for pos_neurons in sample['activated_neurons'].get(layer_name, {}).values():
                counts.update(pos_neurons)
        return counts

    results = {}

    # Every layer that appears in at least one sample of either group.
    layer_names = set()
    for sample in member_neurons + nonmember_neurons:
        layer_names.update(sample['activated_neurons'].keys())

    for layer_name in layer_names:
        member_neuron_counts = count_layer(member_neurons, layer_name)
        nonmember_neuron_counts = count_layer(nonmember_neurons, layer_name)

        # A counter is only non-empty when its group is non-empty, so no division by zero.
        member_freq = {n: count / len(member_neurons) for n, count in member_neuron_counts.items()}
        nonmember_freq = {n: count / len(nonmember_neurons) for n, count in nonmember_neuron_counts.items()}

        member_dominant = {
            neuron: freq for neuron, freq in member_freq.items()
            if neuron not in nonmember_freq or freq > nonmember_freq[neuron] * dominance_ratio
        }
        # Compare against the MEMBER frequency. Comparing a non-member frequency
        # with itself can never succeed, which would leave only neurons that
        # members never fired in this set.
        nonmember_dominant = {
            neuron: freq for neuron, freq in nonmember_freq.items()
            if neuron not in member_freq or freq > member_freq[neuron] * dominance_ratio
        }

        # Neurons seen in both groups that favour neither.
        common_neurons = {
            neuron: (member_freq[neuron], nonmember_freq[neuron])
            for neuron in set(member_freq) & set(nonmember_freq)
            if neuron not in member_dominant and neuron not in nonmember_dominant
        }

        results[layer_name] = {
            'member_dominant': member_dominant,
            'nonmember_dominant': nonmember_dominant,
            'common_neurons': common_neurons,
            'member_counts': member_neuron_counts,
            'nonmember_counts': nonmember_neuron_counts,
            'member_freq': member_freq,
            'nonmember_freq': nonmember_freq
        }

    return results

# 9. Build reference patterns
def build_reference_patterns(train_samples):
    """Compare the activations of labelled training samples by membership.

    Samples with ``label == 1`` are treated as members and ``label == 0`` as
    non-members (any other label is ignored). Returns the per-layer dictionary
    produced by ``analyze_neuron_activation_patterns``.
    """
    member_samples = list(filter(lambda s: s['label'] == 1, train_samples))
    nonmember_samples = list(filter(lambda s: s['label'] == 0, train_samples))
    return analyze_neuron_activation_patterns(member_samples, nonmember_samples)

# 10. Calculate discrimination score for each layer
def calculate_layer_discrimination_scores(reference_patterns):
    """Score each layer as (#member-dominant neurons) - (#non-member-dominant neurons).

    A positive score means the layer has more neurons favouring member samples;
    a negative score means it has more favouring non-member samples.
    """
    return {
        layer: len(patterns['member_dominant']) - len(patterns['nonmember_dominant'])
        for layer, patterns in reference_patterns.items()
    }

# 11. Select the most discriminative layers
def select_discriminative_layers(layer_scores, top_n=10):
    """Return the names of the ``top_n`` layers with the largest absolute score.

    Layers with equal absolute scores keep their order from ``layer_scores``.
    """
    ranked = sorted(layer_scores, key=lambda layer: abs(layer_scores[layer]), reverse=True)
    return ranked[:top_n]

# 12. Membership prediction based on relative ratio
def predict_membership_by_relative_ratio(sample, reference_patterns, discriminative_layers, threshold=1.0):
    """Classify one sample by its overlap with member- vs non-member-dominant neurons.

    A discriminative layer qualifies when it appears both in the sample and in
    the reference patterns, and the sample activated at least one neuron there.
    For each qualifying layer, two fractions are computed: the share of the
    layer's member-dominant neurons the sample hit, and the share of its
    non-member-dominant neurons it hit (0 when the dominant set is empty).
    Each fraction is then averaged over the qualifying layers.

    Returns:
        ``(prediction, avg_member_ratio, avg_nonmember_ratio)``. ``prediction``
        is 1 when ``avg_member_ratio / avg_nonmember_ratio >= threshold``, else 0.
        The quotient is treated as infinite whenever the denominator is 0,
        including when both averages are 0. ``(0, 0, 0)`` is returned when the
        sample has no activations or no layer qualifies.
    """
    fired_by_layer = sample['activated_neurons']
    if not fired_by_layer:
        return 0, 0, 0

    member_fractions = []
    nonmember_fractions = []

    for layer_name in discriminative_layers:
        if layer_name not in fired_by_layer or layer_name not in reference_patterns:
            continue

        # Union of the neurons fired at any token position of this layer.
        fired = set().union(*fired_by_layer[layer_name].values())
        if not fired:
            continue

        patterns = reference_patterns[layer_name]
        member_set = set(patterns['member_dominant'].keys())
        nonmember_set = set(patterns['nonmember_dominant'].keys())

        member_fractions.append(len(fired & member_set) / len(member_set) if member_set else 0)
        nonmember_fractions.append(len(fired & nonmember_set) / len(nonmember_set) if nonmember_set else 0)

    if not member_fractions:
        return 0, 0, 0

    avg_member_ratio = sum(member_fractions) / len(member_fractions)
    avg_nonmember_ratio = sum(nonmember_fractions) / len(nonmember_fractions)

    ratio = float('inf') if avg_nonmember_ratio == 0 else avg_member_ratio / avg_nonmember_ratio
    prediction = 1 if ratio >= threshold else 0
    return prediction, avg_member_ratio, avg_nonmember_ratio

# 14. Find the best threshold on the validation set
def find_best_relative_ratio_threshold(val_samples, reference_patterns, discriminative_layers):
    """Pick the ratio threshold that maximises accuracy on the validation samples.

    Each sample is scored as ``avg_member_ratio / avg_nonmember_ratio`` using
    ``predict_membership_by_relative_ratio``. The score is infinite when the
    non-member ratio is 0, so such samples are always predicted as members.
    Candidate thresholds are six fixed points plus, when at least one score is
    finite, 100 evenly spaced values between the smallest and largest finite
    scores. Among the candidates with the highest accuracy, the smallest wins.

    Returns:
        ``(best_threshold, best_metrics)``, where ``best_metrics`` holds
        accuracy, precision, recall and F1 at that threshold. Returns
        ``(1.0, None)`` when there are no validation samples.
    """
    ratios = []
    labels = []
    for sample in val_samples:
        _, member_ratio, nonmember_ratio = predict_membership_by_relative_ratio(
            sample, reference_patterns, discriminative_layers, threshold=1.0
        )
        ratios.append(float('inf') if nonmember_ratio == 0 else member_ratio / nonmember_ratio)
        labels.append(sample['label'])

    if not ratios:
        return 1.0, None

    # Infinite/NaN scores cannot span a grid. If every score is infinite, min()/max()
    # of an empty list would raise ValueError, so only the fixed points are used then.
    finite_ratios = [r for r in ratios if r != float('inf') and not np.isnan(r)]

    candidate_thresholds = []
    if finite_ratios:
        candidate_thresholds += list(np.linspace(min(finite_ratios), max(finite_ratios), 100))
    candidate_thresholds += [0.5, 1.0, 1.5, 2.0, 3.0, 5.0]
    candidate_thresholds = sorted(set(candidate_thresholds))

    best_threshold = 1.0
    best_accuracy = 0
    best_metrics = None

    for threshold in candidate_thresholds:
        # Infinite scores satisfy `r >= threshold` and are thus predicted as members.
        predictions = [1 if r >= threshold else 0 for r in ratios]
        accuracy = accuracy_score(labels, predictions)
        if accuracy <= best_accuracy:
            continue

        # zero_division=0 gives the same values as the default, minus the warnings.
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, predictions, average='binary', zero_division=0
        )
        best_accuracy = accuracy
        best_threshold = threshold
        best_metrics = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1
        }

    return best_threshold, best_metrics

# Calculate AUC for the test set
def calculate_auc(test_samples, reference_patterns, discriminative_layers, batch_size=10):
    """Compute the ROC AUC of the relative-ratio score on the test samples.

    Activations are collected ``batch_size`` samples at a time. Each sample is
    scored as ``avg_member_ratio / avg_nonmember_ratio`` (label 1 = member).

    A score is infinite when the sample's non-member ratio is 0. Such samples
    are kept and ranked above every finite score, matching the threshold search
    (``inf >= threshold`` always predicts a member). Dropping them would
    silently evaluate the AUC on only part of the test set.

    Returns:
        The AUC as a float, or ``nan`` when fewer than two scored samples or
        only one class are available.
    """
    all_ratios = []
    all_labels = []

    for i in range(0, len(test_samples), batch_size):
        batch_acts = collect_activations(test_samples[i:i + batch_size], batch_size=batch_size)

        for sample in batch_acts:
            _, member_ratio, nonmember_ratio = predict_membership_by_relative_ratio(
                sample, reference_patterns, discriminative_layers
            )
            all_ratios.append(float('inf') if nonmember_ratio == 0 else member_ratio / nonmember_ratio)
            all_labels.append(sample['label'])

        # Release memory between batches.
        del batch_acts
        gc.collect()
        torch.cuda.empty_cache()

    scores = np.asarray(all_ratios, dtype=float)
    labels = np.asarray(all_labels)

    # NaN scores carry no ranking information; drop them.
    keep = ~np.isnan(scores)
    scores, labels = scores[keep], labels[keep]
    if scores.size < 2 or np.unique(labels).size < 2:
        return float('nan')

    # roc_curve needs finite scores: map +inf to a value above every finite score.
    finite = np.isfinite(scores)
    cap = scores[finite].max() + 1.0 if finite.any() else 1.0
    scores = np.where(finite, scores, cap)

    fpr, tpr, _ = roc_curve(labels, scores)
    return auc(fpr, tpr)

# Simply iterate over samples without using tqdm
def iterate_samples(samples):
    """Identity helper: hand ``samples`` back as-is, with no progress-bar wrapping."""
    return samples  # the very same object, not a copy

# Process WikiMIA dataset
def _load_or_build_wikimia_splits(dev_member_count=140):
    """Load the cached WikiMIA dev/test splits, building and caching them if missing.

    Splits are built from the ``WikiMIA_length32`` subset of ``swj0419/WikiMIA``
    ('input' renamed to 'text'). The dev set holds the first ``dev_member_count``
    members and as many non-members. The test set holds the remaining members
    and an equal number of further non-members. Both splits are shuffled with
    seed 42 and saved to ``data/ds_dev`` and ``data/ds_test``.

    Returns:
        ``(ds_dev, ds_test)`` as ``datasets.Dataset`` objects.
    """
    try:
        # First, try to load the prepared dataset directly.
        ds_dev = load_from_disk("data/ds_dev")
        ds_test = load_from_disk("data/ds_test")
    except FileNotFoundError:
        # datasets reports a missing or non-dataset directory with FileNotFoundError;
        # other errors (and KeyboardInterrupt) are no longer swallowed.
        print("Prepared dataset not found, processing WikiMIA data from scratch...")
    else:
        print("Loaded prepared dataset from disk")
        return ds_dev, ds_test

    # Select only the length-32 data.
    length32_data = load_dataset("swj0419/WikiMIA")["WikiMIA_length32"]

    member_samples = []
    nonmember_samples = []
    for item in length32_data:
        # Rename 'input' to 'text' and keep every other field.
        new_item = {'text': item['input'], 'label': item['label']}
        for key, value in item.items():
            if key not in ['input', 'text', 'label']:
                new_item[key] = value
        if item['label'] == 1:
            member_samples.append(new_item)
        else:
            nonmember_samples.append(new_item)

    print(f"Number of member samples: {len(member_samples)}")
    print(f"Number of non-member samples: {len(nonmember_samples)}")

    # Class-balanced split. The dev member count is capped at the number of
    # available members so a smaller download cannot yield a negative test count.
    dev_member_count = min(dev_member_count, len(member_samples))
    test_member_count = len(member_samples) - dev_member_count
    dev_nonmember_count = dev_member_count
    test_nonmember_count = test_member_count

    dev_members = member_samples[:dev_member_count]
    test_members = member_samples[dev_member_count:]
    dev_nonmembers = nonmember_samples[:dev_nonmember_count]
    test_nonmembers = nonmember_samples[dev_nonmember_count:dev_nonmember_count + test_nonmember_count]

    # A local seeded RandomState produces the same shuffles as np.random.seed(42)
    # followed by np.random.shuffle, without resetting NumPy's global random state.
    rng = np.random.RandomState(42)
    ds_dev_list = dev_members + dev_nonmembers
    rng.shuffle(ds_dev_list)
    ds_test_list = test_members + test_nonmembers
    rng.shuffle(ds_test_list)

    # Verify dataset sizes and label balance.
    print(f"Dev set size: {len(ds_dev_list)}")
    print(f"Number of member samples in Dev set: {sum(1 for item in ds_dev_list if item['label'] == 1)}")
    print(f"Number of non-member samples in Dev set: {sum(1 for item in ds_dev_list if item['label'] == 0)}")

    print(f"Test set size: {len(ds_test_list)}")
    print(f"Number of member samples in Test set: {sum(1 for item in ds_test_list if item['label'] == 1)}")
    print(f"Number of non-member samples in Test set: {sum(1 for item in ds_test_list if item['label'] == 0)}")

    ds_dev = Dataset.from_list(ds_dev_list)
    ds_test = Dataset.from_list(ds_test_list)

    os.makedirs("data", exist_ok=True)
    ds_dev.save_to_disk("data/ds_dev")
    ds_test.save_to_disk("data/ds_test")

    print("Datasets have been saved to the data/ds_dev and data/ds_test directories")
    return ds_dev, ds_test


def process_wikimia(thresholds=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0)):
    """Evaluate the neuron-activation membership attack on WikiMIA for each threshold.

    For every activation threshold, this function:
      1. sets the global ``ACTIVATION_THRESHOLD`` read by the forward hooks;
      2. collects activations for the train and validation halves of the dev set;
      3. builds reference patterns from the train half;
      4. keeps the 10 most discriminative layers;
      5. tunes the ratio threshold on the validation half;
      6. prints the test-set AUC.

    Args:
        thresholds: iterable of activation thresholds. The default is a tuple to
            avoid the shared mutable-default pitfall; lists still work.
    """
    global ACTIVATION_THRESHOLD

    print("\n===== Dataset: WikiMIA =====")

    ds_dev, ds_test = _load_or_build_wikimia_splits()

    # The dev split (halved per class into train/validation) and the test list do
    # not depend on the activation threshold, so they are built once.
    dev_member = [item for item in ds_dev if item['label'] == 1]
    dev_nonmember = [item for item in ds_dev if item['label'] == 0]

    train_member = dev_member[:len(dev_member) // 2]
    train_nonmember = dev_nonmember[:len(dev_nonmember) // 2]
    val_member = dev_member[len(dev_member) // 2:]
    val_nonmember = dev_nonmember[len(dev_nonmember) // 2:]

    test_samples = list(ds_test)

    for threshold in thresholds:
        # The hooks read this global on every forward pass, so set it before collecting.
        ACTIVATION_THRESHOLD = threshold

        train_member_acts = collect_activations(train_member)
        train_nonmember_acts = collect_activations(train_nonmember)
        val_member_acts = collect_activations(val_member)
        val_nonmember_acts = collect_activations(val_nonmember)

        reference_patterns = build_reference_patterns(train_member_acts + train_nonmember_acts)
        layer_scores = calculate_layer_discrimination_scores(reference_patterns)
        discriminative_layers = select_discriminative_layers(layer_scores, top_n=10)

        # NOTE(review): the tuned threshold and validation metrics are not used
        # below; only the threshold-free test AUC is reported.
        val_samples = val_member_acts + val_nonmember_acts
        best_threshold, val_metrics = find_best_relative_ratio_threshold(
            val_samples,
            reference_patterns,
            discriminative_layers
        )

        test_auc = calculate_auc(test_samples, reference_patterns, discriminative_layers)

        print(f"Dataset: WikiMIA, Threshold: {threshold:.1f}, Test Set AUC = {test_auc:.4f}")

        # Clean up memory before the next threshold.
        del train_member_acts, train_nonmember_acts, val_member_acts, val_nonmember_acts
        del reference_patterns, layer_scores, discriminative_layers, val_samples
        gc.collect()
        torch.cuda.empty_cache()

    # Release memory after processing the current dataset.
    del ds_dev, ds_test, test_samples
    gc.collect()
    torch.cuda.empty_cache()

# Main function
if __name__ == "__main__":
    try:
        # Set the activation thresholds to be evaluated
        activation_thresholds = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0]

        # Only process the WikiMIA dataset
        process_wikimia(thresholds=activation_thresholds)
    except Exception as e:
        print(f"An error occurred during execution: {str(e)}")
        # Bare `raise` re-raises the active exception with its original traceback.
        raise
    finally:
        # Always detach the forward hooks (on success, on error, and on interrupt)
        # so the next cell starts from a clean model.
        for hook in hooks:
            try:
                hook.remove()
            except Exception:
                pass  # best effort: one failing handle must not block the others

### OPT-6.7B

In [None]:
# Import necessary libraries
import os
import torch
import numpy as np
from collections import Counter
from datasets import load_from_disk, load_dataset, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.metrics import roc_curve, auc, accuracy_score, precision_recall_fscore_support
import gc
import warnings
warnings.filterwarnings("ignore")

# Disable tqdm - a safer way
import sys
import importlib
# Completely disable tqdm
import tqdm

# Define a function that does nothing and returns its input
def dummy_tqdm(iterable=None, *args, **kwargs):
    """Silent replacement for ``tqdm.tqdm``.

    Returns ``iterable`` untouched so loops run normally. When no iterable is
    given, the function returns itself; the no-op attributes attached to this
    function object then act as the progress-bar methods.
    """
    if iterable is None:
        return dummy_tqdm
    return iterable

# Add all necessary attributes and methods: no-op versions of the tqdm class
# attributes and methods that callers may touch on the returned "bar".
dummy_tqdm.format_interval = lambda x: f"{x:.1f}s"
dummy_tqdm.format_meter = lambda *args, **kwargs: ""
dummy_tqdm.format_num = lambda x: str(x)
dummy_tqdm.status_printer = lambda *args, **kwargs: lambda x: None
dummy_tqdm.get_lock = lambda: None
dummy_tqdm.set_lock = lambda x: None
dummy_tqdm.display = lambda *args, **kwargs: None
dummy_tqdm.clear = lambda *args, **kwargs: None
dummy_tqdm.close = lambda *args, **kwargs: None
dummy_tqdm.update = lambda *args, **kwargs: None
dummy_tqdm.refresh = lambda *args, **kwargs: None
dummy_tqdm.disable = True
dummy_tqdm.monitor_interval = 0
dummy_tqdm.monitor = None
dummy_tqdm.pos = 0
# `__iter__`/`__next__` (and `__call__` on the module) are not assigned: Python looks
# special methods up on the type, so setting them on an instance has no effect.

# Replace tqdm in the package and in each tqdm submodule that is already imported.
# Accessing e.g. `tqdm.notebook` directly raises AttributeError when that submodule
# was never imported; looking it up in sys.modules skips it instead.
tqdm.tqdm = dummy_tqdm
for _tqdm_submodule_name in ('tqdm.std', 'tqdm.notebook', 'tqdm.auto', 'tqdm.gui', 'tqdm.cli'):
    _tqdm_submodule = sys.modules.get(_tqdm_submodule_name)
    if _tqdm_submodule is not None:
        _tqdm_submodule.tqdm = dummy_tqdm

# --- Patch missing tqdm.format_sizeof --- #
def _dummy_format_sizeof(num_bytes, *args, **kwargs):
    """
    A fake tqdm.format_sizeof.
    Just returns a simple string of the byte count, enough to trick transformers,
    without affecting your complete disabling of tqdm.
    """
    return f"{num_bytes}"

# Attach the fake format_sizeof to dummy_tqdm when that object has been defined.
try:
    setattr(dummy_tqdm, 'format_sizeof', _dummy_format_sizeof)
except NameError:
    pass  # no dummy_tqdm in this session: nothing to attach to

# Also patch whatever module is currently registered under the name "tqdm".
import sys
if 'tqdm' in sys.modules:
    sys.modules['tqdm'].format_sizeof = _dummy_format_sizeof

# --- Environment: use the GPU when one is visible, otherwise the CPU ---
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# --- OPT-6.7B in FP16; device_map="auto" lets accelerate place the weights ---
model_name = "facebook/opt-6.7b"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()  # inference mode (e.g. dropout disabled)

# Reuse the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

first_param_device = next(model.parameters()).device
print(f"Model loaded, using {model.dtype} precision")
print(f"Model is on device: {first_param_device}")

# Display model memory usage
def get_model_size(model):
    """Return the total size of ``model``'s parameters in GiB (bytes / 1024**3).

    Only ``model.parameters()`` are counted; registered buffers are not.
    """
    total_bytes = 0
    for param in model.parameters():
        total_bytes += param.numel() * param.element_size()
    return total_bytes / 1024**3

model_size_gb = get_model_size(model)
print(f"Model size: {model_size_gb:.2f} GB")

# 4. Set up hook functions to collect neuron activations
# (The pad-token fallback already runs right after the tokenizer is loaded, so it
# is not repeated here.)
activations = {}          # hook name -> detached CPU copy of the hooked module's output
activated_neurons = {}    # hook name -> boolean mask of outputs above ACTIVATION_THRESHOLD
ACTIVATION_THRESHOLD = 0  # Set activation threshold (read by the hooks on every forward pass)

def get_ffn_activation(name):
    """Build a forward hook that records an FFN activation under ``name``.

    On every forward pass the hook stores a detached CPU copy of the module
    output in ``activations[name]``. For 2-D or 3-D outputs it also stores the
    boolean mask ``output > ACTIVATION_THRESHOLD`` (on CPU) in
    ``activated_neurons[name]``. Outputs of any other rank leave
    ``activated_neurons`` untouched.
    """
    def hook(module, input, output):
        activations[name] = output.detach().cpu()

        # Only [batch, hidden] and [batch, seq, hidden] style outputs get a mask.
        if output.dim() not in (2, 3):
            return
        # The global threshold is read at call time, so changing it affects later passes.
        activated_neurons[name] = (output > ACTIVATION_THRESHOLD).detach().cpu()
    return hook

# 5. Register hooks for the model
hooks = []

# Print the module tree of the first decoder layer to show where its activation lives.
sample_layer = model.model.decoder.layers[0]
print("Layer Structure Analysis:")
for name, module in sample_layer.named_modules():
    print(f"  - {name}: {type(module).__name__}")

# Hook every GELU/ReLU module found inside each OPT decoder layer.
for i, layer in enumerate(model.model.decoder.layers):
    activation_found = False

    for name, module in layer.named_modules():
        if not isinstance(module, (torch.nn.GELU, torch.nn.ReLU)):
            continue
        print(f"Registering hook for module {name} in layer {i}")
        hook = module.register_forward_hook(get_ffn_activation(f'layer_{i}_{name}'))
        hooks.append(hook)
        activation_found = True

    if not activation_found:
        print(f"Warning: Activation function not found in layer {i}")

# 7. Collect sample activation data
def process_sample(sample, sample_id):
    """Run one sample through the model and summarise which neurons fired.

    Args:
        sample: mapping with at least ``'text'`` and ``'label'`` keys.
        sample_id: identifier copied into the result.

    Returns:
        dict with ``sample_id``, ``text``, ``label``, ``activated_neurons``
        (hook name -> {token position: [indices of neurons above threshold]})
        and ``input_ids`` (the tokenised input as a NumPy array).
    """
    # Reset the hook buffers so only this sample's forward pass is recorded.
    activations.clear()
    activated_neurons.clear()

    encoded = tokenizer(sample['text'], return_tensors="pt", truncation=True, max_length=512)
    encoded = {key: tensor.to(device) for key, tensor in encoded.items()}

    with torch.no_grad():
        model(**encoded)  # forward pass only; the hooks fill `activated_neurons`

    neurons_by_layer = {}
    for layer_key, mask_tensor in activated_neurons.items():
        try:
            if len(mask_tensor.shape) < 2:
                continue
            mask = mask_tensor.squeeze(0).numpy()  # drop the batch dimension

            if len(mask.shape) == 2:  # [seq_len, hidden_dim]
                per_position = {}
                for position, row in enumerate(mask):
                    hits = np.where(row)[0]
                    if hits.size:
                        per_position[position] = hits.tolist()
                neurons_by_layer[layer_key] = per_position
            elif len(mask.shape) == 1:  # [hidden_dim]
                hits = np.where(mask)[0]
                if hits.size:
                    neurons_by_layer[layer_key] = {0: hits.tolist()}
        except Exception:
            # Best effort: a layer whose mask cannot be converted is left out.
            pass

    return {
        'sample_id': sample_id,
        'text': sample['text'],
        'label': sample['label'],
        'activated_neurons': neurons_by_layer,
        'input_ids': encoded['input_ids'][0].cpu().numpy(),
    }

def collect_activations(samples, batch_size=10):
    """Run `process_sample` over `samples` in chunks and gather the records.

    Args:
        samples: A sliceable sequence (e.g. a list) of sample dicts.
        batch_size: Number of samples processed between memory clean-ups.

    Returns:
        List of per-sample records from `process_sample`, in input order.
        A sample whose processing raises is skipped, and a warning is printed.
    """
    results = []

    for start in range(0, len(samples), batch_size):
        batch = samples[start:start + batch_size]

        for offset, sample in enumerate(batch):
            sample_id = start + offset
            try:
                results.append(process_sample(sample, sample_id))
            except Exception as e:
                # Best effort: keep going, but make the dropped sample visible.
                print(f"Warning: failed to process sample {sample_id}: {type(e).__name__}: {e}")

        # Free Python and CUDA caches between batches (not needed after the last one).
        if start + batch_size < len(samples):
            gc.collect()
            torch.cuda.empty_cache()

    return results

# 8. Analyze neuron activation patterns
def _count_layer_neurons(samples, layer_name):
    """Count activations per neuron index in `layer_name`, summed over samples and positions."""
    counts = Counter()
    for sample in samples:
        for pos_neurons in sample['activated_neurons'].get(layer_name, {}).values():
            counts.update(pos_neurons)
    return counts


def analyze_neuron_activation_patterns(member_neurons, nonmember_neurons, dominance_ratio=1.5):
    """Compare per-neuron activation statistics between member and non-member samples.

    Args:
        member_neurons: List of records (as built by `process_sample`) for member
            samples; each has an ``'activated_neurons'`` mapping
            ``{layer_name: {position: [neuron_index, ...]}}``.
        nonmember_neurons: Same, for non-member samples.
        dominance_ratio: A neuron is "dominant" for one group when it never fires
            in the other group, or when its frequency exceeds ``dominance_ratio``
            times its frequency in the other group. Defaults to 1.5.

    Returns:
        ``{layer_name: stats}`` in first-seen layer order, where ``stats`` has:
        'member_dominant' / 'nonmember_dominant' (``{neuron: freq}``),
        'common_neurons' (``{neuron: (member_freq, nonmember_freq)}``),
        'member_counts' / 'nonmember_counts' (``Counter``) and
        'member_freq' / 'nonmember_freq' (``{neuron: freq}``).

    "freq" is the activation count divided by the number of samples in the
    group. A neuron firing at several positions of one sample is counted once
    per position, so freq can exceed 1.
    """
    results = {}

    # Ordered de-duplication (first-seen order). With a `set`, the layer order,
    # and therefore tie-breaking in the stable sort of
    # `select_discriminative_layers`, would depend on per-process string hash
    # randomization.
    layer_names = dict.fromkeys(
        layer_name
        for sample in member_neurons + nonmember_neurons
        for layer_name in sample['activated_neurons']
    )

    n_member = len(member_neurons)
    n_nonmember = len(nonmember_neurons)

    for layer_name in layer_names:
        member_neuron_counts = _count_layer_neurons(member_neurons, layer_name)
        nonmember_neuron_counts = _count_layer_neurons(nonmember_neurons, layer_name)

        # A count dict is non-empty only if its group is non-empty, so these never divide by zero.
        member_freq = {neuron: count / n_member for neuron, count in member_neuron_counts.items()}
        nonmember_freq = {neuron: count / n_nonmember for neuron, count in nonmember_neuron_counts.items()}

        member_dominant = {
            neuron: freq
            for neuron, freq in member_freq.items()
            if neuron not in nonmember_freq or freq > nonmember_freq[neuron] * dominance_ratio
        }
        nonmember_dominant = {
            neuron: freq
            for neuron, freq in nonmember_freq.items()
            if neuron not in member_freq or freq > member_freq[neuron] * dominance_ratio
        }
        # Neurons seen in both groups that are dominant in neither.
        common_neurons = {
            neuron: (member_freq[neuron], nonmember_freq[neuron])
            for neuron in set(member_freq) & set(nonmember_freq)
            if neuron not in member_dominant and neuron not in nonmember_dominant
        }

        results[layer_name] = {
            'member_dominant': member_dominant,
            'nonmember_dominant': nonmember_dominant,
            'common_neurons': common_neurons,
            'member_counts': member_neuron_counts,
            'nonmember_counts': nonmember_neuron_counts,
            'member_freq': member_freq,
            'nonmember_freq': nonmember_freq
        }

    return results

# 9. Build reference patterns
def build_reference_patterns(train_samples):
    """Derive per-layer reference activation patterns from labelled records.

    Records whose ``'label'`` equals 1 form the member group and those equal
    to 0 the non-member group; any other label is ignored. The two groups are
    compared with `analyze_neuron_activation_patterns`, whose result is
    returned unchanged.
    """
    members = []
    nonmembers = []
    for record in train_samples:
        label = record['label']
        if label == 1:
            members.append(record)
        elif label == 0:
            nonmembers.append(record)

    return analyze_neuron_activation_patterns(members, nonmembers)

# 10. Calculate discrimination score for each layer
def calculate_layer_discrimination_scores(reference_patterns):
    """Score every layer by how lopsided its dominant-neuron sets are.

    The score is ``len(member_dominant) - len(nonmember_dominant)``. A positive
    value means the layer has more member-dominant neurons, and a negative
    value more non-member-dominant ones.

    Returns:
        ``{layer_name: int score}``, in the iteration order of `reference_patterns`.
    """
    return {
        layer: len(layer_data['member_dominant']) - len(layer_data['nonmember_dominant'])
        for layer, layer_data in reference_patterns.items()
    }

# 11. Select the most discriminative layers
def select_discriminative_layers(layer_scores, top_n=10):
    """Return up to `top_n` layer names ranked by absolute score, strongest first.

    The sort is stable, so layers with equal ``abs(score)`` keep the order
    in which they appear in `layer_scores`.
    """
    ranked = sorted(layer_scores, key=lambda name: abs(layer_scores[name]), reverse=True)
    return ranked[:top_n]

# 12. Membership prediction based on relative ratio
def _overlap_fraction(neurons, dominant):
    """Return the fraction of `dominant` (dict keys or set) that appears in the set `neurons`; 0 if `dominant` is empty."""
    dominant_set = set(dominant)
    if not dominant_set:
        return 0
    return len(neurons & dominant_set) / len(dominant_set)


def predict_membership_by_relative_ratio(sample, reference_patterns, discriminative_layers, threshold=1.0):
    """Predict whether one sample is a member by comparing dominant-neuron overlap.

    For each layer in `discriminative_layers` present in both the sample and
    `reference_patterns`, the neurons the sample activated at any position
    are compared with that layer's dominant sets:

        member_ratio    = |sample & member_dominant|    / |member_dominant|
        nonmember_ratio = |sample & nonmember_dominant| / |nonmember_dominant|

    Each ratio is 0 when its dominant set is empty, and layers where the sample
    activated nothing are skipped. Both ratios are averaged over the counted
    layers. The sample is predicted a member when
    ``avg_member_ratio / avg_nonmember_ratio >= threshold``. That quotient is
    treated as infinite when only the non-member average is zero.

    Returns:
        ``(prediction, avg_member_ratio, avg_nonmember_ratio)`` with
        prediction 1 for member and 0 for non-member. ``(0, 0, 0)`` is returned
        when no layer could be scored. Prediction is 0 when both averages are
        zero, since the sample matched neither dominant set and so there is no
        evidence of membership.
    """
    if not sample['activated_neurons']:
        return 0, 0, 0  # No activation data: default to non-member.

    layers_counted = 0
    total_member_ratio = 0
    total_nonmember_ratio = 0

    for layer_name in discriminative_layers:
        if layer_name not in sample['activated_neurons'] or layer_name not in reference_patterns:
            continue

        layer_data = reference_patterns[layer_name]

        # Union of neurons this sample activated at any position of the layer.
        sample_neurons = set()
        for pos_neurons in sample['activated_neurons'][layer_name].values():
            sample_neurons.update(pos_neurons)

        if not sample_neurons:
            continue

        total_member_ratio += _overlap_fraction(sample_neurons, layer_data['member_dominant'])
        total_nonmember_ratio += _overlap_fraction(sample_neurons, layer_data['nonmember_dominant'])
        layers_counted += 1

    if layers_counted == 0:
        return 0, 0, 0  # No scorable layer: default to non-member.

    avg_member_ratio = total_member_ratio / layers_counted
    avg_nonmember_ratio = total_nonmember_ratio / layers_counted

    if avg_nonmember_ratio == 0:
        if avg_member_ratio == 0:
            # 0/0: no overlap with either dominant set, so keep the non-member default.
            return 0, avg_member_ratio, avg_nonmember_ratio
        ratio = float('inf')
    else:
        ratio = avg_member_ratio / avg_nonmember_ratio

    return 1 if ratio >= threshold else 0, avg_member_ratio, avg_nonmember_ratio

# 13. Find the best threshold on the validation set
def find_best_relative_ratio_threshold(val_samples, reference_patterns, discriminative_layers):
    """Pick the member/non-member ratio threshold that maximises validation accuracy.

    Each validation record is scored with
    ``avg_member_ratio / avg_nonmember_ratio`` (from
    `predict_membership_by_relative_ratio`; infinite when the non-member
    average is 0), and is predicted a member when its score is at least the
    threshold.

    Candidate thresholds are 100 evenly spaced points between the smallest
    and largest finite score, plus the fixed points 0.5, 1, 1.5, 2, 3 and 5.
    Infinite scores do not shape the grid but are still evaluated, and are
    always predicted member.

    Returns:
        ``(best_threshold, best_metrics)``. ``best_metrics`` is a dict with
        'accuracy', 'precision', 'recall' and 'f1' at ``best_threshold``. The
        smallest threshold wins among ties. ``(1.0, None)`` is returned when
        `val_samples` is empty.
    """
    ratios = []
    labels = []
    for sample in val_samples:
        _, member_ratio, nonmember_ratio = predict_membership_by_relative_ratio(
            sample, reference_patterns, discriminative_layers, threshold=1.0
        )
        ratios.append(float('inf') if nonmember_ratio == 0 else member_ratio / nonmember_ratio)
        labels.append(sample['label'])

    best_threshold = 1.0
    best_metrics = None
    if not labels:
        return best_threshold, best_metrics

    # Only finite scores can span a grid. If every score is infinite, fall back
    # to the fixed points instead of crashing on min()/max() of an empty list.
    finite_ratios = [r for r in ratios if np.isfinite(r)]
    candidate_thresholds = []
    if finite_ratios:
        candidate_thresholds += list(np.linspace(min(finite_ratios), max(finite_ratios), 100))
    candidate_thresholds += [0.5, 1.0, 1.5, 2.0, 3.0, 5.0]
    candidate_thresholds = sorted(set(candidate_thresholds))

    # Start below any achievable accuracy so best_metrics is always filled in.
    best_accuracy = -1.0
    for threshold in candidate_thresholds:
        predictions = [1 if r >= threshold else 0 for r in ratios]

        accuracy = accuracy_score(labels, predictions)
        # zero_division=0 gives the same values as the default ("warn") without
        # emitting a warning for every degenerate threshold.
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, predictions, average='binary', zero_division=0
        )

        # Strict '>' keeps the smallest threshold among ties (candidates ascend).
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_threshold = threshold
            best_metrics = {
                'accuracy': accuracy,
                'precision': precision,
                'recall': recall,
                'f1': f1
            }

    return best_threshold, best_metrics

# Calculate AUC for the test set
def calculate_auc(test_samples, reference_patterns, discriminative_layers, batch_size=10):
    """Compute the ROC AUC of the relative-ratio score on the test samples.

    Activations are collected `batch_size` samples at a time and freed after
    each batch to bound memory. Each sample is scored with
    ``avg_member_ratio / avg_nonmember_ratio`` (infinite when the non-member
    average is 0).

    NOTE(review): samples with an infinite score are excluded before the ROC
    computation, so the AUC covers only the finitely-scored subset of
    `test_samples`.

    Returns:
        The AUC as a float, or ``nan`` when the scorable subset does not
        contain both classes (a warning is printed in that case).
    """
    all_ratios = []
    all_labels = []

    for start in range(0, len(test_samples), batch_size):
        batch_acts = collect_activations(test_samples[start:start + batch_size], batch_size=batch_size)

        for sample in batch_acts:
            _, member_ratio, nonmember_ratio = predict_membership_by_relative_ratio(
                sample, reference_patterns, discriminative_layers
            )
            all_ratios.append(float('inf') if nonmember_ratio == 0 else member_ratio / nonmember_ratio)
            all_labels.append(sample['label'])

        # Activation records are large; release each batch before the next one.
        del batch_acts
        gc.collect()
        torch.cuda.empty_cache()

    # roc_curve rejects non-finite scores, so keep only finite ones.
    finite_pairs = [(ratio, lab) for ratio, lab in zip(all_ratios, all_labels) if np.isfinite(ratio)]
    filtered_ratios = [ratio for ratio, _ in finite_pairs]
    filtered_labels = [lab for _, lab in finite_pairs]

    present_classes = sorted(set(filtered_labels))
    if len(present_classes) < 2:
        # roc_curve raises on empty input and yields an undefined curve for a single class.
        print(
            f"Warning: AUC undefined ({len(filtered_labels)} scorable test samples, "
            f"classes present: {present_classes}); returning nan"
        )
        return float('nan')

    fpr, tpr, _ = roc_curve(filtered_labels, filtered_ratios)
    return auc(fpr, tpr)

# Simply iterate over samples without using tqdm
def iterate_samples(samples):
    """Hand back `samples` itself (same object), as a plain no-progress-bar stand-in for tqdm."""
    passthrough = samples
    return passthrough

# Process WikiMIA dataset
def _load_or_build_wikimia_splits():
    """Return ``(ds_dev, ds_test)``, loading them from ``data/`` or building and saving them from WikiMIA_length32."""
    try:
        # First, try to load the prepared dataset directly.
        ds_dev = load_from_disk("data/ds_dev")
        ds_test = load_from_disk("data/ds_test")
        print("Loaded prepared dataset from disk")
        return ds_dev, ds_test
    except Exception:
        # `except Exception` rather than a bare `except:` so Ctrl-C / SystemExit still propagate.
        print("Prepared dataset not found, processing WikiMIA data from scratch...")

    # Make sure `Dataset` is in scope for building the splits.
    from datasets import Dataset

    ds = load_dataset("swj0419/WikiMIA")
    length32_data = ds["WikiMIA_length32"]

    # Split by label, renaming the 'input' field to 'text' and carrying over any other fields.
    member_samples = []
    nonmember_samples = []
    for item in length32_data:
        new_item = {
            'text': item['input'],
            'label': item['label']
        }
        for key, value in item.items():
            if key not in ['input', 'text', 'label']:
                new_item[key] = value

        if item['label'] == 1:
            member_samples.append(new_item)
        else:
            nonmember_samples.append(new_item)

    print(f"Number of member samples: {len(member_samples)}")
    print(f"Number of non-member samples: {len(nonmember_samples)}")

    # 140 members go to dev and the rest to test (247 for the 387 members of
    # WikiMIA_length32). Non-members are matched count-for-count for balance.
    dev_member_count = 140
    test_member_count = max(0, len(member_samples) - dev_member_count)
    dev_nonmember_count = dev_member_count
    test_nonmember_count = test_member_count

    # Seed for a reproducible shuffle of the combined splits.
    np.random.seed(42)

    dev_members = member_samples[:dev_member_count]
    test_members = member_samples[dev_member_count:]

    dev_nonmembers = nonmember_samples[:dev_nonmember_count]
    test_nonmembers = nonmember_samples[dev_nonmember_count:dev_nonmember_count + test_nonmember_count]

    ds_dev_list = dev_members + dev_nonmembers
    np.random.shuffle(ds_dev_list)

    ds_test_list = test_members + test_nonmembers
    np.random.shuffle(ds_test_list)

    # Verify dataset sizes and label balance.
    print(f"Dev set size: {len(ds_dev_list)}")
    print(f"Number of member samples in Dev set: {sum(1 for item in ds_dev_list if item['label'] == 1)}")
    print(f"Number of non-member samples in Dev set: {sum(1 for item in ds_dev_list if item['label'] == 0)}")

    print(f"Test set size: {len(ds_test_list)}")
    print(f"Number of member samples in Test set: {sum(1 for item in ds_test_list if item['label'] == 1)}")
    print(f"Number of non-member samples in Test set: {sum(1 for item in ds_test_list if item['label'] == 0)}")

    ds_dev = Dataset.from_list(ds_dev_list)
    ds_test = Dataset.from_list(ds_test_list)

    os.makedirs("data", exist_ok=True)
    ds_dev.save_to_disk("data/ds_dev")
    ds_test.save_to_disk("data/ds_test")

    print("Datasets have been saved to the data/ds_dev and data/ds_test directories")
    return ds_dev, ds_test


def process_wikimia(thresholds=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0)):
    """Evaluate the neuron-activation membership attack on WikiMIA (length 32).

    For each value in `thresholds` this function:
      1. sets the global ``ACTIVATION_THRESHOLD`` to that value;
      2. builds reference activation patterns from half of the dev split;
      3. selects the 10 most discriminative layers;
      4. tunes a ratio threshold on the other half of the dev split;
      5. prints the test-set AUC.

    Args:
        thresholds: Iterable of activation-threshold values to evaluate. A
            tuple default avoids the shared-mutable-default pitfall; any
            iterable (e.g. a list) may be passed.
    """
    global ACTIVATION_THRESHOLD

    print("\n===== Dataset: WikiMIA =====")

    ds_dev, ds_test = _load_or_build_wikimia_splits()

    # The train/validation/test split does not depend on the activation
    # threshold, so build it once instead of on every iteration.
    dev_member = [item for item in ds_dev if item['label'] == 1]
    dev_nonmember = [item for item in ds_dev if item['label'] == 0]

    # Split each class of the dev set evenly into training and validation halves.
    train_member = dev_member[:len(dev_member) // 2]
    train_nonmember = dev_nonmember[:len(dev_nonmember) // 2]
    val_member = dev_member[len(dev_member) // 2:]
    val_nonmember = dev_nonmember[len(dev_nonmember) // 2:]

    test_list = list(ds_test)

    for threshold in thresholds:
        # NOTE(review): presumably read by the forward hooks registered
        # earlier in this file to decide which neurons count as active. Confirm.
        ACTIVATION_THRESHOLD = threshold

        # Activations depend on the threshold, so they are re-collected per iteration.
        train_member_acts = collect_activations(train_member)
        train_nonmember_acts = collect_activations(train_nonmember)

        val_member_acts = collect_activations(val_member)
        val_nonmember_acts = collect_activations(val_nonmember)

        reference_patterns = build_reference_patterns(
            train_member_acts + train_nonmember_acts
        )
        layer_scores = calculate_layer_discrimination_scores(reference_patterns)
        discriminative_layers = select_discriminative_layers(layer_scores, top_n=10)

        val_samples = val_member_acts + val_nonmember_acts
        # NOTE(review): the tuned threshold and validation metrics are not used
        # below. Only the threshold-free test AUC is reported.
        best_threshold, val_metrics = find_best_relative_ratio_threshold(
            val_samples,
            reference_patterns,
            discriminative_layers
        )

        test_auc = calculate_auc(test_list, reference_patterns, discriminative_layers)

        print(f"Dataset: WikiMIA, Threshold: {threshold:.1f}, Test Set AUC = {test_auc:.4f}")

        # Release per-threshold activation data before the next iteration.
        del train_member_acts, train_nonmember_acts, val_member_acts, val_nonmember_acts
        del reference_patterns, layer_scores, discriminative_layers, val_samples
        gc.collect()
        torch.cuda.empty_cache()

    # Release memory after processing the dataset.
    del ds_dev, ds_test, test_list
    gc.collect()
    torch.cuda.empty_cache()

# Main function
if __name__ == "__main__":
    # Activation thresholds to evaluate (each is passed to `process_wikimia`).
    activation_thresholds = [1.0]

    try:
        # Only the WikiMIA dataset is processed.
        process_wikimia(thresholds=activation_thresholds)
    except Exception as e:
        print(f"An error occurred during execution: {str(e)}")
        # Bare `raise` re-raises with the original traceback.
        raise
    finally:
        # Always detach the forward hooks (on success, error, or Ctrl-C) so the
        # model is left unmodified for later cells.
        for hook in hooks:
            try:
                hook.remove()
            except Exception as remove_err:
                print(f"Warning: failed to remove hook: {remove_err}")