In [None]:
import random
import re
import torch
from datasets import load_dataset
from transformers import GPTNeoXForCausalLM, AutoTokenizer
import matplotlib.pyplot as plt
from tqdm import tqdm
import gc
import os
import warnings

# Silence two known-noisy warning sources so they do not bury the progress output
warnings.filterwarnings(action="ignore", category=FutureWarning)
warnings.filterwarnings(action="ignore", message="TypedStorage is deprecated")

# Configure PyTorch's CUDA allocator (128 MB max split size) to mitigate memory fragmentation
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:128"})

# Configuration
MAX_CONTEXT_TOKENS = 2048  # total token budget: prompt plus generated tokens
MAX_NEW_TOKENS = 10        # upper bound on tokens generated per prompt
BATCH_SIZE = 1             # one prompt per batch keeps GPU memory usage minimal

# Fetch the evaluation corpus
print("Loading dataset...")
ds = load_dataset("NeelNanda/pile-10k", split="train")

# Draw a reproducible (seeded) sample of 1000 texts
print("Sampling 1000 sentences...")
random.seed(42)
sampled_indices = random.sample(range(len(ds)), k=1000)
sampled_sentences = [record["text"] for record in (ds[idx] for idx in sampled_indices)]

# The default-revision tokenizer provides the reference (ground truth) token counts
print("Loading base tokenizer for ground truth...")
base_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m-deduped")

# Reference count = number of token ids per text, with no special tokens added
print("Computing actual token counts...")
actual_counts = [
    len(token_ids)
    for token_ids in (
        base_tokenizer.encode(text, add_special_tokens=False)
        for text in sampled_sentences
    )
]

# Checkpoints to evaluate: every 10k steps from 0 through 140000, plus step 143000
# (presumably the last published checkpoint -- confirm against the model repository)
steps = [*range(0, 150000, 10000)]
if 143000 not in steps:
    steps += [143000]

# Run on the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Per-step results, stored as (step, value) pairs
accuracies, mean_abs_diffs = [], []

# Function to split sentences into batches
def split_into_batches(sentences, batch_size):
    """Lazily yield consecutive slices of ``sentences``.

    Each yielded slice holds at most ``batch_size`` items; only the final
    slice may be shorter. An empty input yields nothing.
    """
    total = len(sentences)
    for start in range(0, total, batch_size):
        stop = start + batch_size
        yield sentences[start:stop]

# Iterate over each step with a single progress bar.
#
# For every checkpoint in `steps` this loop:
#   1. loads the model and tokenizer for that revision,
#   2. prompts the model (greedy decoding) to state how many tokens each
#      sampled text contains,
#   3. parses the first integer in the reply (-1 when there is none), and
#   4. records exact-match accuracy and the mean absolute difference
#      normalized by the actual count, both measured against `actual_counts`.
# A step that raises is reported and skipped (nothing is appended for it).
# Model and tokenizer are released before the next checkpoint is loaded.
for step in tqdm(steps, desc="Processing Steps"):
    try:
        # Load model for the current step with half precision to save memory
        revision = f"step{step}"
        print(f"\nLoading model for revision: {revision}")
        model = GPTNeoXForCausalLM.from_pretrained(
            "EleutherAI/pythia-70m-deduped",
            revision=revision,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            low_cpu_mem_usage=True
        ).to(device)

        model.eval()

        # Load tokenizer for this revision
        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m-deduped", revision=revision)

        # A decoder-only model continues each row from its last position, so
        # padding has to be on the left; with right padding, shorter rows in a
        # batch would be continued from pad tokens. No effect when a batch
        # holds a single prompt.
        tokenizer.padding_side = "left"

        # Ensure tokenizer has a pad_token
        if tokenizer.pad_token is None:
            print("Tokenizer does not have a pad_token. Setting pad_token to eos_token.")
            tokenizer.pad_token = tokenizer.eos_token
            # Keep the embedding matrix size in line with len(tokenizer)
            model.resize_token_embeddings(len(tokenizer))

        predicted_counts = []

        # Process sentences in batches without an inner progress bar
        for batch_sentences in split_into_batches(sampled_sentences, BATCH_SIZE):
            prompts = [
                (
                    "Answer only the last question with a single number only. Do not say any other words before "
                    "that number. Do not say 'Sure' or anything like that. Your task is to count how many tokens a "
                    "sentence has. The sentence is always written between <start> and <end> delimiters, e.g. <start> "
                    "sentence <end>. Don’t count the surrounding delimiters.\n\n"
                    "How many tokens does the sentence have?\n"
                    f"<start> {sentence} <end>"
                )
                for sentence in batch_sentences
            ]

            # Tokenize prompts with truncation so prompt + generated tokens
            # fit within the context length
            encodings = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True,
                                  max_length=MAX_CONTEXT_TOKENS - MAX_NEW_TOKENS)
            input_ids = encodings['input_ids'].to(device)
            attention_mask = encodings['attention_mask'].to(device)

            # Generate responses (greedy decoding, no gradient tracking)
            with torch.no_grad():
                outputs = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=MAX_NEW_TOKENS,
                    do_sample=False,
                    pad_token_id=tokenizer.pad_token_id
                )

            # Each output row is the (padded) prompt followed by the newly
            # generated tokens, and every row shares the same prompt width, so
            # the continuation starts at that width. Counting non-pad ids
            # instead would be wrong: the pad id may equal the EOS id, so any
            # occurrence of that id inside a prompt would shift the offset.
            prompt_width = input_ids.shape[1]

            # Decode and extract predicted counts
            for sequence in outputs:
                output_str = tokenizer.decode(
                    sequence[prompt_width:],
                    skip_special_tokens=True
                )
                match = re.search(r"\d+", output_str.strip())
                # -1 marks "no number in the reply"; it never equals a real
                # count, so it is scored as a miss below.
                pred_count = int(match.group(0)) if match else -1
                predicted_counts.append(pred_count)

        # Compute accuracy (fraction of exact matches)
        exact_matches = sum(1 for a, p in zip(actual_counts, predicted_counts) if a == p)
        accuracy = exact_matches / len(actual_counts)

        # Compute mean absolute difference normalized by actual count
        # (texts with an actual count of 0 contribute 0 to avoid dividing by zero)
        diffs = [(abs(a - p) / a) if a != 0 else 0 for a, p in zip(actual_counts, predicted_counts)]
        mean_diff = sum(diffs) / len(diffs)

        # Append results
        accuracies.append((step, accuracy))
        mean_abs_diffs.append((step, mean_diff))

        # Print the results for the current step
        print(f"Step {step}: Accuracy = {accuracy*100:.2f}%, Mean Normalized Absolute Difference = {mean_diff:.4f}")

    except torch.cuda.OutOfMemoryError as oom:
        print(f"OutOfMemoryError at step {step}: {oom}")
    except Exception as e:
        print(f"Error at step {step}: {e}")
    finally:
        # Free up GPU memory: drop the references, collect garbage, and only
        # then ask the CUDA caching allocator to release its unused blocks.
        # Emptying the cache before collection cannot return memory that is
        # still held by not-yet-collected objects.
        if 'model' in locals():
            del model
        if 'tokenizer' in locals():
            del tokenizer
        gc.collect()
        torch.cuda.empty_cache()

# Unpack the (step, value) result pairs into plottable series
steps_plot = [entry[0] for entry in accuracies]
y_acc = [entry[1] for entry in accuracies]
y_diff = [entry[1] for entry in mean_abs_diffs]

# One figure, two side-by-side panels
plt.figure(figsize=(14, 6))

# Left panel: exact-match accuracy per step, y-axis fixed to [0, 1]
plt.subplot(1, 2, 1)
plt.plot(steps_plot, y_acc, linestyle="-", marker="o")
plt.xlabel("Step")
plt.ylabel("Accuracy")
plt.title("Accuracy vs Steps")
plt.ylim(0, 1)
plt.grid(True)

# Right panel: mean normalized absolute difference per step, drawn in red
plt.subplot(1, 2, 2)
plt.plot(steps_plot, y_diff, linestyle="-", marker="o", color="red")
plt.xlabel("Step")
plt.ylabel("Mean Normalized Absolute Difference")
plt.title("Mean Normalized Absolute Difference vs Steps")
plt.grid(True)

plt.tight_layout()
plt.show()


Loading dataset...
