In [1]:
# !pip install --upgrade transformers

In [2]:
# Record the GPU, driver and CUDA version used for this run (kept in the output for provenance).
!nvidia-smi

Wed May  7 04:01:40 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.154.05             Driver Version: 535.154.05   CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:27:00.0 Off |                    0 |
| N/A   33C    P0              77W / 400W |  19768MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [3]:
import ast
import json
import os
from pathlib import Path

import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Turn off the Hugging Face `tokenizers` library's internal parallelism for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Hugging Face Hub checkpoint evaluated throughout this notebook.
MODEL_NAME = "Qwen/Qwen3-8B"

In [None]:
# Cache-first loading: resume from the (partially filled) output parquet when it exists,
# otherwise start fresh from the source TSV.
in_path = Path("data_qwen3/source/mmlu_pro_stem.tsv").resolve()
out_path = Path("data_qwen3/out/qwen3_8b_mmlu_entropy.parquet").resolve()

# Number of questions evaluated (only applied when starting fresh from the TSV).
N_QUESTIONS = 2000

# Ensure checkpoints can be written; otherwise the first df.to_parquet in the generation loop
# fails only after an expensive generation has already run.
out_path.parent.mkdir(parents=True, exist_ok=True)

if out_path.exists():
    df = pd.read_parquet(out_path)
else:
    df = pd.read_csv(in_path, sep="\t", header=0)
    # The TSV stores each option list as a string literal; parse it back into a Python list.
    df["options"] = df["options"].apply(ast.literal_eval)
    df = df.iloc[:N_QUESTIONS]

# Answer labels "1".."20"; a model answer is only scored if it matches one of these strings.
option_ids = [str(i + 1) for i in range(20)]


def enumerate_question_and_options(question, options):
    """
    Build the user prompt: the question followed by its answer options numbered from 1.

    Parameters
    ----------
    question : str
        Question text; surrounding whitespace is stripped.
    options : iterable of str
        Answer options in display order (a list, or the array pandas yields for a list column
        read back from parquet — TODO confirm for other sources).

    Returns
    -------
    str
        Prompt text ending with the instruction to answer with the option NUMBER only.
    """
    # Number with enumerate instead of zipping against a fixed label list, so no option is
    # silently dropped when a question has more options than predefined labels.
    options_str = "\n".join(
        f"{option_id}. {answer}".strip() for option_id, answer in enumerate(options, start=1)
    )
    return (
        f"Question: {question.strip()}\nOptions:\n{options_str}\n"
        "Choose one of the answers. Write down ONLY the NUMBER of the correct answer and nothing else."
    )

### Answer generation


In [6]:
# Tokenizer first (cheap), then the weights.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse the EOS id as the padding id; generation below runs with batch size 1.
tokenizer.pad_token_id = tokenizer.eos_token_id

# Keep the checkpoint's stored dtype and let `device_map="auto"` place the shards on the available device(s).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto",
)

Loading checkpoint shards: 100%|██████████| 5/5 [00:04<00:00,  1.21it/s]


In [None]:
def get_embeddings(model, tokenizer, text):
    """
    Pool the last-layer hidden states of ``text`` into fixed-size vectors.

    Parameters
    ----------
    model : causal LM
        Called once with ``output_hidden_states=True`` under ``torch.no_grad()``.
    tokenizer : tokenizer
        Encodes ``text`` as-is (no chat template is applied here).
    text : str
        Text to embed.

    Returns
    -------
    dict[str, list[float]] | None
        ``{"min": ..., "max": ..., "mean": ...}``: element-wise min / max / mean over the token
        positions of the last hidden layer (one value per hidden unit, computed in float32).
        ``None`` if ``text`` is empty or the forward pass fails; a warning is emitted on failure.
    """
    import warnings

    # Pooling over zero token positions is undefined, so skip empty text (e.g. an empty
    # answer after stripping) up front instead of relying on the exception handler.
    if not text:
        return None
    try:
        inputs = tokenizer(text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True, return_dict=True)
        # hidden_states[-1] is the last layer; [0] selects the single batch element -> (seq_len, hidden).
        last_hidden = outputs.hidden_states[-1][0].float()
        return {
            "min": last_hidden.min(dim=0).values.cpu().tolist(),
            "max": last_hidden.max(dim=0).values.cpu().tolist(),
            "mean": last_hidden.mean(dim=0).cpu().tolist(),
        }
    # TODO: Investigate why it fails for Qwen 3B only for specific rows; the warning records the error.
    except Exception as exc:  # best-effort: one bad row must not abort the whole generation loop
        warnings.warn(f"get_embeddings failed for text of length {len(text)}: {exc!r}", stacklevel=2)
        return None

In [8]:
from dataclasses import dataclass, field
from typing import Any


def compute_entropy_from_logits(logits: torch.Tensor) -> float:
    """
    Compute the Shannon entropy (in nats) of the softmax distribution defined by ``logits``.

    Parameters
    ----------
    logits : torch.Tensor
        Unnormalised scores of a single distribution (1-D; ``.item()`` below requires the
        result to be one value). Entries may be ``-inf`` (masked tokens); they contribute zero.

    Returns
    -------
    float
        Entropy value.
    """
    # Work in float32 whatever the input dtype: a half-precision softmax over a large
    # vocabulary is too imprecise for the small probabilities that dominate the tail.
    probabilities = torch.softmax(logits.float(), dim=-1)
    # entr(p) = -p * ln(p) with entr(0) = 0, so zero-probability tokens need no epsilon.
    entropy = torch.special.entr(probabilities).sum(dim=-1)
    return entropy.item()


@dataclass
class LogitSeqStats:
    """Per-position statistics for one generated sequence, as filled by collect_logit_sequence_stats."""

    # Argmax token id of the scores at each generated position (0-d tensors). NOTE: this is the
    # argmax of the scores passed to the collector, which is not necessarily the token that was
    # actually emitted when generation samples.
    greedy_tokens: list[torch.Tensor] = field(default_factory=list)
    # Entropy of the softmax distribution at each generated position.
    entropies: list[float] = field(default_factory=list)
    # For each generated position: a list of {"token_idx": int, "token_prob": float} entries, one per
    # vocabulary token whose probability exceeds the collector's small cut-off.
    every_token_stats: list[list[dict[str, Any]]] = field(default_factory=list)


def collect_logit_sequence_stats(logits: list[torch.Tensor], prob_cutoff: float = 1e-5):
    """
    Collect per-position entropy, probability and argmax statistics for one generated sequence.

    Parameters
    ----------
    logits : list[torch.Tensor] | tuple[torch.Tensor, ...]
        One tensor per generated position, each of shape ``(batch_size, vocab_size)``; only batch
        element 0 is used (batch size 1 is assumed). Pass ``outputs.scores`` from
        ``model.generate(..., output_scores=True, return_dict_in_generate=True)`` here.
    prob_cutoff : float, default 1e-5
        Only tokens whose probability is strictly greater than this are kept in ``every_token_stats``.

    Returns
    -------
    LogitSeqStats
        ``entropies[i]``: entropy at position ``i``.
        ``every_token_stats[i]``: ``{"token_idx": int, "token_prob": float}`` for each token above
        ``prob_cutoff``, in increasing token-id order.
        ``greedy_tokens[i]``: argmax token id as a 0-d tensor.
    """
    stats = LogitSeqStats()
    for step_scores in logits:
        # Batch element 0 only.
        token_logits = step_scores[0]
        stats.entropies.append(compute_entropy_from_logits(token_logits))

        # Softmax in float32: scores may arrive in the model's half-precision dtype (depending on
        # the transformers version), and `.numpy()` cannot convert bfloat16 tensors at all.
        probabilities = torch.softmax(token_logits.float(), dim=-1)
        kept_indices = torch.nonzero(probabilities > prob_cutoff, as_tuple=True)[0]
        kept_probs = probabilities[kept_indices]
        # One device->host transfer per list, then plain Python ints / floats for JSON.
        stats.every_token_stats.append(
            [
                {"token_idx": token_idx, "token_prob": token_prob}
                for token_idx, token_prob in zip(kept_indices.tolist(), kept_probs.tolist())
            ]
        )

        stats.greedy_tokens.append(token_logits.argmax(dim=-1))

    return stats


In [9]:
def check_answer_correct_mmlu(row, model_answer):
    """
    Check a numeric model answer against the row's 0-based ``answer_index``.

    Options are shown to the model numbered from 1, so the expected answer is ``answer_index + 1``.

    Parameters
    ----------
    row : mapping (e.g. a DataFrame row)
        Must provide ``"answer_index"``.
    model_answer : str
        The model's answer text; surrounding whitespace is ignored.

    Returns
    -------
    bool
        ``True`` if the answer matches. ``False`` if it does not, if it is not an integer, or if the
        row has no usable ``answer_index``.
    """
    try:
        return int(row["answer_index"]) + 1 == int(model_answer.strip())
    # Narrow handler: a bare `except` would also swallow KeyboardInterrupt / SystemExit.
    except (KeyError, TypeError, ValueError, AttributeError, OverflowError):
        return False

In [None]:
import gc

invalid_answers = 0

field_response = "qwen_8b_response"
field_ans_token_index = "qwen_8b_ans_token_index"
field_ans_correct = "qwen_8b_ans_correct"
field_entropies_value = "qwen_8b_entropies"
field_every_token_info = "qwen_8b_every_token_info"
field_input_embeddings = "qwen_8b_input_embeddings"
field_think_embeddings = "qwen_8b_think_embeddings"
field_answer_embeddings = "qwen_8b_answer_embeddings"

# Default value of every result column. Columns that already exist (e.g. when resuming from the
# parquet checkpoint) are left untouched.
result_column_defaults = {
    field_ans_correct: False,
    field_entropies_value: "",
    field_every_token_info: "",
    field_ans_token_index: -1,
    field_response: "",
    field_input_embeddings: "",
    field_think_embeddings: "",
    field_answer_embeddings: "",
}
for column_name, default_value in result_column_defaults.items():
    if column_name not in df.columns:
        df[column_name] = default_value

# Loop invariants: generation budget, checkpoint period, and the token id closing the reasoning block.
max_new_tokens = 5000
checkpoint_every = 500
think_token_idx = tokenizer.convert_tokens_to_ids("</think>")

for index, row in tqdm(df.iterrows(), total=df.shape[0]):
    # Resume support: a row whose answer position is recorded is done. Rows where "</think>"
    # was never produced keep -1 and are regenerated on the next run.
    if df.at[index, field_ans_token_index] != -1:
        continue

    gc.collect()
    torch.cuda.empty_cache()

    sys_prompt = f"The following are multiple choice questions about {row['base_cluster']}. Write down ONLY the NUMBER of the correct answer and nothing else."
    user_prompt = enumerate_question_and_options(row["question"], row["options"])

    messages = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": user_prompt},
    ]
    formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
    # NOTE(review): `temperature` only matters when sampling; `do_sample` is not passed, so whether
    # this samples is decided by the model's generation_config — confirm this is intended.
    # NOTE(review): `output_scores` yields the *processed* scores (after temperature / top-k / top-p
    # processors); use `output_logits=True` / `outputs.logits` if raw-distribution entropies are wanted.
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        output_scores=True,
        return_dict_in_generate=True,
        temperature=0.1,
        num_beams=1,
        pad_token_id=tokenizer.eos_token_id,
    )

    input_length = inputs.input_ids.shape[1]
    response_raw = outputs.sequences[0, input_length:]
    df.at[index, field_response] = tokenizer.decode(response_raw, skip_special_tokens=True)

    logit_stats = collect_logit_sequence_stats(outputs.scores)
    df.at[index, field_entropies_value] = json.dumps(logit_stats.entropies)
    df.at[index, field_every_token_info] = json.dumps(logit_stats.every_token_stats)

    # Split the response at "</think>" using the tokens that were actually generated: the argmax
    # of the scores (logit_stats.greedy_tokens) can differ from the emitted token when sampling,
    # which would make the reconstructed answer disagree with the stored response.
    response_tokens = response_raw.tolist()
    answer_text = ""
    if think_token_idx in response_tokens:
        answer_token_idx = response_tokens.index(think_token_idx) + 1
        # `.at` writes this row only; `df[col] = value` would overwrite the whole column and the
        # resume check above would then skip every remaining row.
        df.at[index, field_ans_token_index] = answer_token_idx

        think_text = tokenizer.decode(response_raw[:answer_token_idx])
        think_embeddings = get_embeddings(model, tokenizer, think_text)
        if think_embeddings is not None:
            df.at[index, field_think_embeddings] = json.dumps(think_embeddings)

        answer_text = tokenizer.decode(response_raw[answer_token_idx:], skip_special_tokens=True).strip()
        answer_embeddings = get_embeddings(model, tokenizer, answer_text)
        if answer_embeddings is not None:
            df.at[index, field_answer_embeddings] = json.dumps(answer_embeddings)

    input_embeddings = get_embeddings(model, tokenizer, formatted_prompt)
    if input_embeddings is not None:
        df.at[index, field_input_embeddings] = json.dumps(input_embeddings)

    if answer_text in option_ids:
        df.at[index, field_ans_correct] = check_answer_correct_mmlu(row, answer_text)
    else:
        invalid_answers += 1

    # Drop the large generation outputs (one vocab-sized score tensor per step) now, so they are
    # not still alive while the next row's generate() runs.
    del outputs, logit_stats

    if index % checkpoint_every == 0:
        df.to_parquet(out_path, compression="gzip")

df.to_parquet(out_path, compression="gzip")
print(f"Invalid answers in this run: {invalid_answers}")

  0%|          | 3/2000 [03:19<35:26:17, 63.88s/it]