# inits

In [1]:
!python -m pip install -q accelerate bitsandbytes transformers torch datasets evaluate tabulate fsspec


In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch, time
from config import ModelArgs, HF_MODEL_PATH, HF_TOKENIZER_PATH
import warnings
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
args = ModelArgs()
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True

# Using torch.float32

In [4]:
model = AutoModelForCausalLM.from_pretrained(
    HF_MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.float32
)
print(model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

# using fp4 quant (bitsandbytes)

In [5]:
# Load model in fp4 using bitsandbytes
fp4_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="fp4",
)
fp4_model = AutoModelForCausalLM.from_pretrained(
    HF_MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.float32,
    quantization_config=fp4_config
)
print(fp4_model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear4bit(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear4bit(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear4bit(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear4bit(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear4bit(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), 

# using nf4 quant bitsandbytes

In [6]:
# Load model in nf4 using bitsandbytes
nf4_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
)
nf4_model = AutoModelForCausalLM.from_pretrained(
    HF_MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.float32,
    quantization_config=nf4_config
)
print(nf4_model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear4bit(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear4bit(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear4bit(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear4bit(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear4bit(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), 

# Running generation

In [7]:
tokenizer = AutoTokenizer.from_pretrained(HF_TOKENIZER_PATH, trust_remote_code=True)


prompt = "Today was a perfect day"


In [8]:
inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
start = time.time()
output_ids = model.generate(
    **inputs,
    max_new_tokens=args.max_new_tokens,
    do_sample=args.do_sample,
    pad_token_id=tokenizer.eos_token_id
)
elapsed = time.time() - start
tokens_generated = output_ids.shape[-1] - inputs.input_ids.shape[-1]
print(f"with torch.float32")
fp32_out = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(fp32_out)
print()
print(f"Token count: {tokens_generated}, elapsed: {elapsed:.2f}s, {tokens_generated/elapsed:.0f} tokens/s")

with torch.float32
Today was a perfect day for a walk. The weather was sunny and warm, the air was crisp and clean, and the sky was a deep blue. I was in a good mood, and I was in the mood to walk. I had a long list of things to do, but I didn’t want to do them all at once. I wanted to take my time and enjoy the walk.
I started out by walking down the street. I walked for a few minutes, and then I turned around and walked back. I walked for a few more minutes, and then I turned around again and walked back. I walked for a few more minutes, and then I turned around again and walked back. I walked for a few more minutes, and then I turned around again

Token count: 150, elapsed: 1.38s, 109 tokens/s


In [9]:
inputs = tokenizer(prompt, return_tensors='pt').to(fp4_model.device)
start = time.time()
output_ids = fp4_model.generate(
    **inputs,
    max_new_tokens=args.max_new_tokens,
    do_sample=args.do_sample,
    pad_token_id=tokenizer.eos_token_id
)
elapsed = time.time() - start
tokens_generated = output_ids.shape[-1] - inputs.input_ids.shape[-1]
print(f"with fp4")
fp4_out = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(fp4_out)
print(f"Token count: {tokens_generated}, elapsed: {elapsed:.2f}s, {tokens_generated/elapsed:.0f} tokens/s")

with fp4
Today was a perfect day to go to the beach. The water was calm and the waves were small. I was able to walk on the beach and not get my feet wet. I was able to walk on the beach and not get my feet wet. I was able to walk on the beach and not get my feet wet. I was able to walk on the beach and not get my feet wet. I was able to walk on the beach and not get my feet wet. I was able to walk on the beach and not get my feet wet. I was able to walk on the beach and not get my feet wet. I was able to walk on the beach and not get my feet wet. I was able to walk on the beach and not get my feet wet
Token count: 150, elapsed: 0.94s, 160 tokens/s


In [10]:
inputs = tokenizer(prompt, return_tensors='pt').to(nf4_model.device)
start = time.time()
output_ids = nf4_model.generate(
    **inputs,
    max_new_tokens=args.max_new_tokens,
    do_sample=args.do_sample,
    pad_token_id=tokenizer.eos_token_id
)
elapsed = time.time() - start
tokens_generated = output_ids.shape[-1] - inputs.input_ids.shape[-1]
print(f"with nf4")
nf4_out = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(nf4_out)
print(f"Token count: {tokens_generated}, elapsed: {elapsed:.2f}s, {tokens_generated/elapsed:.0f} tokens/s")

with nf4
Today was a perfect day to go to the beach. We went to the beach in the morning and then went to the park in the afternoon. We had a great time. We went to the park and played on the swings and the slides. We also went to the playground and played on the jungle gym. We had a great time. We also went to the beach and played on the swings and the slides. We also went to the playground and played on the jungle gym. We had a great time. We also went to the beach and played on the swings and the slides. We also went to the playground and played on the jungle gym. We had a great time. We also went to the beach and played on the swings and the slides. We also went to
Token count: 150, elapsed: 0.93s, 162 tokens/s


In [11]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from math import exp

def compute_perplexity(model, tokenizer, texts, max_length=512):
    model.eval()
    ppl_scores = []
    for text in texts:
        enc = tokenizer(text, return_tensors='pt', truncation=True, max_length=max_length)
        input_ids = enc.input_ids.to(model.device)
        with torch.no_grad():
            outputs = model(input_ids, labels=input_ids)
            loss = outputs.loss

        ppl_scores.append(exp(loss.item()))
    return ppl_scores

In [12]:
texts = [
    "Once upon a time",
    "The capital of France is",
    "What is 17 * 28?"
]


ppl_fp32 = compute_perplexity(model, tokenizer, texts)
ppl_fp4 = compute_perplexity(fp4_model, tokenizer, texts)
ppl_nf4 = compute_perplexity(nf4_model, tokenizer, texts)
print("FP32 PPL:", ppl_fp32)
print("FP4  PPL:", ppl_fp4)
print("NF4  PPL:", ppl_nf4)


FP32 PPL: [14.894348644381546, 47.592817913932805, 57.660402623079534]
FP4  PPL: [15.627858879456415, 35.94554888214959, 62.58180678166403]
NF4  PPL: [13.277682313925535, 54.149774907452674, 60.16393450325215]


In [13]:
from tabulate import tabulate

rows = []
for i, text in enumerate(texts):
    rows.append([
        text,
        round(ppl_fp32[i], 2),
        round(ppl_fp4[i], 2),
        round(ppl_nf4[i], 2)
    ])

print(tabulate(rows, headers=["Text", "FP32 PPL", "FP4 PPL", "NF4 PPL"]))

Text                        FP32 PPL    FP4 PPL    NF4 PPL
------------------------  ----------  ---------  ---------
Once upon a time               14.89      15.63      13.28
The capital of France is       47.59      35.95      54.15
What is 17 * 28?               57.66      62.58      60.16


In [14]:
from datasets import load_dataset

ds = load_dataset("cais/mmlu", "all")

# eval function

In [15]:
def format_mmlu_prompt(question, choices):
    formatted_choices = "\n".join([f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices)])
    return f"Question: {question}\n\nChoices:\n{formatted_choices}\n\nAnswer:"

def evaluate_model(model, tokenizer, dataset, num_samples=None):
    model.eval()
    correct = 0
    total = 0
    
    # Take a subset if num_samples is specified
    samples = dataset if num_samples is None else dataset.select(range(num_samples))
    
    results = []
    for item in samples:
        # Format the prompt
        prompt = format_mmlu_prompt(item['question'], item['choices'])
        
        # Tokenize and generate
        inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=False
            )
        
        # Get the predicted answer
        generated = tokenizer.decode(outputs[0][-1:], skip_special_tokens=True).strip()
        
        # Map the output to A, B, C, D
        pred_idx = None
        for i, letter in enumerate(['A', 'B', 'C', 'D']):
            if letter in generated:
                pred_idx = i
                break
        
        # If no valid letter found, count as wrong
        if pred_idx is None:
            correct_answer = chr(65 + item['answer'])
            results.append({
                'question': item['question'],
                'prediction': generated,
                'correct': correct_answer,
                'is_correct': False
            })
            continue
            
        # Check if correct
        is_correct = pred_idx == item['answer']
        if is_correct:
            correct += 1
        total += 1
        
        results.append({
            'question': item['question'],
            'prediction': chr(65 + pred_idx),
            'correct': chr(65 + item['answer']),
            'is_correct': is_correct
        })
    
    accuracy = correct / total if total > 0 else 0
    return accuracy, results

# running evals

In [16]:
# Load the dataset
print("Loading MMLU Abstract Algebra dataset...")
eval_ds = ds["test"]  # Using test split for evaluation

# Number of samples to evaluate
num_eval_samples = 1000

print("\nEvaluating FP32 model...")
fp32_acc, fp32_results = evaluate_model(model, tokenizer, eval_ds, num_eval_samples)

print("\nEvaluating FP4 model...")
fp4_acc, fp4_results = evaluate_model(fp4_model, tokenizer, eval_ds, num_eval_samples)

print("\nEvaluating NF4 model...")
nf4_acc, nf4_results = evaluate_model(nf4_model, tokenizer, eval_ds, num_eval_samples)

Loading MMLU Abstract Algebra dataset...

Evaluating FP32 model...

Evaluating FP4 model...

Evaluating NF4 model...


In [17]:
# Print results
print(f"\nResults(% of {num_eval_samples}):")
print(f"FP32 Accuracy: {fp32_acc:.2%}")
print(f"FP4 Accuracy:  {fp4_acc:.2%}")
print(f"NF4 Accuracy:  {nf4_acc:.2%}")


Results(% of 1000):
FP32 Accuracy: 33.90%
FP4 Accuracy:  31.50%
NF4 Accuracy:  36.30%


In [18]:
# Print detailed results for first few examples
print("\nSample Results:")
print("\nFP32 Model:")
for result in fp32_results[:3]:
    print(f"Q: {result['question']}")
    print(f"Predicted: {result['prediction']}, Correct: {result['correct']}, Is Correct: {result['is_correct']}\n")


Sample Results:

FP32 Model:
Q: Find the degree for the given field extension Q(sqrt(2), sqrt(3), sqrt(18)) over Q.
Predicted: B, Correct: B, Is Correct: True

Q: Let p = (1, 2, 5, 4)(2, 3) in S_5 . Find the index of <p> in S_5.
Predicted: B, Correct: C, Is Correct: False

Q: Find all zeros in the indicated finite field of the given polynomial with coefficients in that field. x^5 + 3x^3 + x^2 + 2x in Z_5
Predicted: A, Correct: D, Is Correct: False

