In [None]:
import os
import random

import numpy as np
import torch

# Import utils functions
from utils import load_model, get_final_representation

# Load model
# The checkpoint directory can be overridden with the MODEL_PATH environment
# variable so the script is not tied to one machine's filesystem layout.
# When the variable is unset, the original location is used unchanged.
MODEL_PATH = os.environ.get(
    "MODEL_PATH", "/home/chashi/Desktop/Research/My Projects/models"
)
# load_model (from utils) hands back the model together with its tokenizer.
model, tokenizer = load_model(MODEL_PATH)

# Define factual and unfactual questions
# "Factual" questions have a well-defined answer; "unfactual" ones are
# nonsensical or unanswerable by construction.
factual_questions = [
    "What is the capital of France?",
    "How many days are in a year?",
    "What is 2 plus 2?",
    "Who wrote Romeo and Juliet?",
    "What is the chemical symbol for water?",
    "What planet is closest to the Sun?",
    "How many continents are there?",
    "What year did World War II end?",
    "What is the largest ocean on Earth?",
    "Who painted the Mona Lisa?",
]

unfactual_questions = [
    "What color is the sound of Tuesday?",
    "How many dreams fit in a teaspoon?",
    "What is the weight of my grandmother's favorite memory?",
    "Which number tastes the most like purple?",
    "What will I be thinking about on March 15, 2087?",
    "How fast do unicorns run?",
    "What is the temperature of invisible fire?",
    "Which emotion is exactly 7 inches tall?",
    "What is the secret ingredient in moonlight?",
    "How many wishes live in a broken clock?",
]

# Define roles
# Each role sentence is prepended to a question, separated by one space.
roles = [
    "You are a mathematics professor.",
    "You are a high school student.",
    "You are a professional chef.",
    "You are a famous film star.",
]

print("Creating dataset...")

# Create all question-role combinations.
# all_prompts[i] is described by labels[i]; labels look like
# "<kind>_base" (no role) or "<kind>_<role index>".
all_prompts = []
labels = []

# Insertion order matters: factual prompts always precede unfactual ones.
question_sets = {
    "factual": factual_questions,
    "unfactual": unfactual_questions,
}

# Base questions without roles (factual first, then unfactual)
for kind, questions in question_sets.items():
    for question in questions:
        all_prompts.append(question)
        labels.append(f"{kind}_base")

# Questions with roles (factual first, then unfactual), question-major order.
# enumerate supplies the role's position directly, which avoids a linear
# roles.index() search per prompt and stays correct even if the same role
# sentence were ever listed twice.
for kind, questions in question_sets.items():
    for question in questions:
        for role_idx, role in enumerate(roles):
            all_prompts.append(f"{role} {question}")
            labels.append(f"{kind}_{role_idx}")

print(f"Total prompts created: {len(all_prompts)}")

# Get LM head weights once
lm_head_weights = model.lm_head.weight.detach().cpu().float().numpy()

# The per-row norms of the LM head do not depend on the prompt, so compute
# them once here instead of re-scanning the whole matrix for every prompt.
lm_head_norms = np.linalg.norm(lm_head_weights, axis=1)

# The model's device is likewise fixed for the whole run.
device = next(model.parameters()).device

# Extract representations, logits, and LM head embeddings.
# All three lists are index-aligned with all_prompts.
representations = []
all_logits = []
all_lm_head_embeddings = []

for i, prompt in enumerate(all_prompts):
    if i % 10 == 0:
        print(f"Processing {i+1}/{len(all_prompts)}")

    # Get final representation.
    # NOTE(review): assumes get_final_representation returns a CPU tensor in
    # a NumPy-compatible dtype (.numpy() is called on it directly) - confirm
    # against utils.
    repr_vec = get_final_representation(model, tokenizer, prompt)
    hidden_rep = repr_vec.numpy()
    representations.append(hidden_rep)

    # Get logits for this prompt with a separate forward pass.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        # Get logits for the last token of the (single) sequence in the batch
        last_logits = outputs.logits[0, -1, :].cpu().float().numpy()
    all_logits.append(last_logits)

    # Get the nearest LM head embedding: the row of the LM head with the
    # highest cosine similarity to the final representation.
    similarities = np.dot(hidden_rep, lm_head_weights.T) / (
        np.linalg.norm(hidden_rep) * lm_head_norms
    )
    nearest_idx = np.argmax(similarities)
    all_lm_head_embeddings.append(lm_head_weights[nearest_idx])

# Convert to numpy arrays
representations = np.array(representations)
all_logits = np.array(all_logits)
all_lm_head_embeddings = np.array(all_lm_head_embeddings)

print(f"Representations shape: {representations.shape}")
print(f"Logits shape: {all_logits.shape}")
print(f"LM head embeddings shape: {all_lm_head_embeddings.shape}")

# Get random samples from embedding and unembedding weights
print("Sampling embedding and unembedding weights...")

# Embedding weights (input embeddings)
embed_weights = model.model.embed_tokens.weight.detach().cpu().float().numpy()
vocab_size = embed_weights.shape[0]

# Unembedding weights (LM head)
unembed_weights = model.lm_head.weight.detach().cpu().float().numpy()

# Random sample indices.
# The same indices select rows from both matrices, so they are drawn only
# from rows that exist in both. The sample size is capped at the population
# size because random.sample raises ValueError when asked for more items
# than are available (e.g. a model with a vocabulary below 10000 tokens).
# NOTE(review): the sampling is unseeded and the indices are not saved, so
# the sampled rows differ between runs.
num_shared_rows = min(vocab_size, unembed_weights.shape[0])
sample_size = min(10000, num_shared_rows)
random_indices = random.sample(range(num_shared_rows), sample_size)

# Row i of each sample corresponds to the same token index in both matrices.
embed_sample = embed_weights[random_indices]
unembed_sample = unembed_weights[random_indices]

print(f"Embedding sample shape: {embed_sample.shape}")
print(f"Unembedding sample shape: {unembed_sample.shape}")

# Save everything
print("Saving files...")

# Save all arrays (written to the current working directory)
np.save('question_representations.npy', representations)
np.save('question_logits.npy', all_logits)
np.save('question_lm_head_embeddings.npy', all_lm_head_embeddings)
np.save('embedding_weights_sample.npy', embed_sample)
np.save('unembedding_weights_sample.npy', unembed_sample)

# Horizontal rule written under every section heading of the text report.
section_rule = "=" * 50 + "\n"

# Save questions and labels as text.
# The encoding is pinned to UTF-8 so the files do not depend on the
# platform's default locale encoding.
with open('questions_and_prompts.txt', 'w', encoding='utf-8') as f:
    f.write("FACTUAL QUESTIONS:\n")
    f.write(section_rule)
    f.writelines(f"{q}\n" for q in factual_questions)

    f.write("\nUNFACTUAL QUESTIONS:\n")
    f.write(section_rule)
    f.writelines(f"{q}\n" for q in unfactual_questions)

    # Roles are listed with the index that appears in the "<kind>_<index>" labels.
    f.write("\nROLES:\n")
    f.write(section_rule)
    f.writelines(f"{i}: {role}\n" for i, role in enumerate(roles))

    f.write("\nALL PROMPTS AND LABELS:\n")
    f.write(section_rule)
    f.writelines(
        f"{label}: {prompt}\n" for prompt, label in zip(all_prompts, labels)
    )

# Save labels separately for easy loading: one label per line, in the same
# order as the rows of the saved arrays.
with open('labels.txt', 'w', encoding='utf-8') as f:
    f.writelines(f"{label}\n" for label in labels)

print("Files saved:")
print("- question_representations.npy")
print("- question_logits.npy")
print("- question_lm_head_embeddings.npy")
print("- embedding_weights_sample.npy")
print("- unembedding_weights_sample.npy")
print("- questions_and_prompts.txt")
print("- labels.txt")

# Human-readable recap of what was generated: counts of questions, roles and
# prompts, followed by the width (second axis) of each saved per-prompt array.
dataset_summary = (
    "\nDataset summary:",
    f"- {len(factual_questions)} factual questions",
    f"- {len(unfactual_questions)} unfactual questions",
    f"- {len(roles)} roles + base (no role)",
    f"- {len(all_prompts)} total prompt combinations",
    f"- Representation dimension: {representations.shape[1]}",
    f"- Logits dimension: {all_logits.shape[1]}",
    f"- LM head embeddings dimension: {all_lm_head_embeddings.shape[1]}",
)
for summary_line in dataset_summary:
    print(summary_line)

# Quick load test: read every saved array back from disk and report its
# shape. Any unreadable file makes np.load raise before the final message.
print("\nTesting file loading...")
(
    test_reps,
    test_logits,
    test_lm_embeddings,
    test_embed,
    test_unembed,
) = (
    np.load(saved_path)
    for saved_path in (
        'question_representations.npy',
        'question_logits.npy',
        'question_lm_head_embeddings.npy',
        'embedding_weights_sample.npy',
        'unembedding_weights_sample.npy',
    )
)

print(f"Loaded representations shape: {test_reps.shape}")
print(f"Loaded logits shape: {test_logits.shape}")
print(f"Loaded LM head embeddings shape: {test_lm_embeddings.shape}")
print(f"Loaded embedding sample shape: {test_embed.shape}")
print(f"Loaded unembedding sample shape: {test_unembed.shape}")
print("All files loaded successfully!")