In [None]:
!pip install datasets
!pip install evaluate
!pip install accelerate
!pip install bitsandbytes

Collecting datasets
  Downloading datasets-3.5.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.1-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.4/491.4 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2025.3.0-py3-none-any.whl (1

In [None]:
import json

import numpy as np
import torch
import torch.nn as nn
import bitsandbytes as bnb

import datasets
from datasets import load_dataset, load_from_disk
from datasets.arrow_dataset import Dataset
from evaluate import load

from transformers import pipeline
from transformers import BitsAndBytesConfig
from transformers import AutoModelForCausalLM
from transformers.pipelines.pt_utils import KeyDataset

import accelerate

from tqdm import tqdm


import matplotlib.pyplot as plt

In [None]:
# Checkpoint used for the bfloat16 baseline and for the quantized variants.
model_id = "meta-llama/Llama-3.2-1B-Instruct"

# Baseline text-generation pipeline with the weights loaded in bfloat16.
pipe = pipeline('text-generation', model=model_id, torch_dtype=torch.bfloat16)
# The model config exposes a list of eos_token_id values; its first entry is
# reused as the pad token (the original note says this is 128001 -- not
# verifiable from this notebook alone).
pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id[0]

# Module-level handles: the tokenizer is handed to the quantized pipelines and
# the model is used for the parameter-memory comparison.
_model = pipe.model
_tokenizer = pipe.tokenizer

In [None]:
# 8-bit bitsandbytes variant of the same checkpoint.
# NOTE(review): bnb_4bit_compute_dtype looks like a 4-bit-only option -- confirm
# whether it has any effect together with load_in_8bit.
quant8_config = BitsAndBytesConfig(load_in_8bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
quant8_model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quant8_config)
# The baseline tokenizer object (_tokenizer) is passed in rather than loading a new one.
quant8_pipe = pipeline('text-generation', model=quant8_model, tokenizer=_tokenizer, torch_dtype='auto')
# The config holds a list of eos_token_id values; the first entry is used as the pad token.
quant8_pipe.tokenizer.pad_token_id = quant8_pipe.model.config.eos_token_id[0]

In [None]:
# 4-bit (NF4) bitsandbytes variant of the same checkpoint, computing in bfloat16.
quant4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type='nf4')
quant4_model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quant4_config)
# The baseline tokenizer object (_tokenizer) is passed in rather than loading a new one.
quant4_pipe = pipeline('text-generation', model=quant4_model, tokenizer=_tokenizer, torch_dtype='auto')
# The config holds a list of eos_token_id values; the first entry is used as the pad token.
quant4_pipe.tokenizer.pad_token_id = quant4_pipe.model.config.eos_token_id[0]

In [None]:
def _parameter_megabytes(model) -> float:
    """Return the summed storage size of ``model``'s parameters in units of 1024**2 bytes."""
    total_bytes = 0
    for tensor in model.parameters():
        # numel() * element_size() is the byte size of the tensor as it is stored.
        total_bytes += tensor.numel() * tensor.element_size()
    return total_bytes / (1024 ** 2)


# Parameter footprint of the baseline and of each quantized model.
llama_1b = _parameter_megabytes(_model)
llama_1b_int8 = _parameter_megabytes(quant8_model)
llama_1b_int4 = _parameter_megabytes(quant4_model)

print(f'LLama 3.2 1B Instruct with bfloat16 uses {round(llama_1b, 2)}MB of memory')
print(f'LLama 3.2 1B Instruct with bfloat16 and int8 uses {round(llama_1b_int8, 2)}MB of memory')
print(f'LLama 3.2 1B Instruct with bfloat16 and int4 uses {round(llama_1b_int4, 2)}MB of memory')

### CoQA Dataset without conversations

In [None]:
def get_prompt_for_coqa_question(context: str, question: str) -> list:
    """
    Build the two-message chat prompt used to ask a single CoQA question.

    Args:
        context: Passage text, embedded verbatim in the system message.
        question: Question text, used as the user message.

    Returns:
        A list ``[system_message, user_message]`` of ``{'role', 'content'}`` dicts.
    """
    system_message = {
        'role': 'system',
        'content': f'You are a chatbot which answers user question in extremely concise manner possible from given context, "{context}".',
    }
    user_message = {'role': 'user', 'content': question}
    return [system_message, user_message]

In [None]:
coqa_dataset = load_dataset('stanfordnlp/coqa', split='all')

# Work on the first 100 rows only.
coqa_100 = coqa_dataset.select(range(100))

# One chat prompt per row, built from the story and the first of its questions.
first_question_prompts = [
    get_prompt_for_coqa_question(story, story_questions[0])
    for story, story_questions in zip(coqa_100['story'], coqa_100['questions'])
]
coqa_100 = coqa_100.add_column('question_prompt', column=first_question_prompts)

coqa_100

In [None]:
# Inspect the chat prompt that was built for the first row.
coqa_100['question_prompt'][0]

In [None]:
# Number of times every prompt is sent through each model.
SAMPLING = 50

# Select as per the GPU?
MULTIPLIER = 30

# Generation settings shared by all three pipelines.
pipe_kwargs = dict(
    top_k=0,
    top_p=0.9,  # Nucleus sampling: cumulative probability threshold
    max_new_tokens=128,
    pad_token_id=_tokenizer.pad_token_id,
    batch_size=5 * MULTIPLIER,
)

del MULTIPLIER

# The tokenizer pads on the right by default, which makes Hugging Face warn:
# "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer."
_tokenizer.padding_side = 'left'

# The same repeated prompt dataset is fed to every pipeline.
inference_ds = KeyDataset(coqa_100.repeat(SAMPLING), key='question_prompt')
n = len(inference_ds)

labelled_pipelines = (
    ('bfloat16', pipe),
    ('int8', quant8_pipe),
    ('int4', quant4_pipe),
)

# Label -> list of generated chats, in dataset order.
generations = {label: [] for label, _ in labelled_pipelines}

for label, generator in labelled_pipelines:
    with tqdm(total=n, desc=label) as pbar:
        for out in generator(inference_ds, **pipe_kwargs):
            # Keep the 'generated_text' of the first returned sequence.
            generations[label].append(out[0]['generated_text'])
            pbar.update()

torch.cuda.empty_cache()

In [None]:
def get_prompt_for_verification(questions: list, answers: dict, chat_history: list, context: str) -> list:
    """
    Build the chat prompt asking a judge model whether a generated reply is correct.

    Only the first question and its reference answer are used; the reply being
    judged is the content of the last message in ``chat_history``.

    Args:
        questions: Sequence of question strings; ``questions[0]`` is used.
        answers: Mapping with an ``'input_text'`` sequence of reference answers;
            ``answers['input_text'][0]`` is used.
        chat_history: Chat messages (dicts with a ``'content'`` key); the last
            message is treated as the response to verify.
        context: Passage the question refers to.

    Returns:
        A list ``[system_message, user_message]`` of ``{'role', 'content'}`` dicts.

    Raises:
        IndexError: If ``questions``, ``answers['input_text']`` or ``chat_history`` is empty.
        KeyError: If ``answers`` has no ``'input_text'`` entry or the last
            message has no ``'content'`` entry.
    """
    question = questions[0]
    answer = answers['input_text'][0]
    response = chat_history[-1]['content']

    system_message = {
        'role': 'system',
        'content': 'For the following query is the response correct reply with True or False, nothing more. Look at the context or the answer',
    }
    user_message = {
        'role': 'user',
        'content': f'From the context "{context}" the question "{question}", has the correct answer as "{answer}". Does the response "{response}" fits the correct answer?',
    }
    return [system_message, user_message]

In [None]:
# Same row layout as the dataset that was fed to the generation pipelines.
coqa_100_generations = coqa_100.repeat(SAMPLING)

for precision, responses in generations.items():
    response_column = f'{precision}_response'
    coqa_100_generations = coqa_100_generations.add_column(response_column, responses)

    # One judge prompt per row, pairing the reference data with this model's reply.
    verification_prompts = [
        get_prompt_for_verification(row_questions, row_answers, row_chat, row_story)
        for row_questions, row_answers, row_chat, row_story in zip(
            coqa_100_generations['questions'],
            coqa_100_generations['answers'],
            coqa_100_generations[response_column],
            coqa_100_generations['story'],
        )
    ]
    coqa_100_generations = coqa_100_generations.add_column(
        f'{precision}_verification_prompt', column=verification_prompts
    )

coqa_100_generations

## Verify using Judge model

In [None]:
# Checkpoint used as the judge; a different model id from the one being evaluated.
judge_model_id = "meta-llama/Llama-3.2-3B-Instruct"

# Judge text-generation pipeline with the weights loaded in bfloat16.
judge_pipe = pipeline('text-generation', model=judge_model_id, torch_dtype=torch.bfloat16)
# The config holds a list of eos_token_id values; the first entry is used as the pad token.
judge_pipe.tokenizer.pad_token_id = judge_pipe.model.config.eos_token_id[0]

In [None]:
# Call arguments for the judge pipeline.
# NOTE(review): no max_new_tokens is passed, so the pipeline's default generation
# length applies -- confirm that is intended for a True/False reply.
verification_pipe_kwargs = {
    'pad_token_id': judge_pipe.tokenizer.pad_token_id,
    'batch_size': 50
}

# Pad on the left, as is done for the generation tokenizer.
judge_pipe.tokenizer.padding_side = 'left'
# Number of times every judge prompt is evaluated.
VERIFICATION_SAMPLES = 25

# Label -> list of judge chats, in dataset order.
verifications = {
    'bfloat16': [],
    'int8': [],
    'int4': []
}

for key in verifications.keys():
    # Every row of coqa_100_generations is repeated VERIFICATION_SAMPLES times for the judge.
    verification_ds = KeyDataset(coqa_100_generations.repeat(VERIFICATION_SAMPLES), key=f'{key}_verification_prompt')
    n = len(verification_ds)

    with tqdm(total=n, desc=f'{key}_verification') as pbar:
        for out in judge_pipe(verification_ds, **verification_pipe_kwargs):
            # Keep the 'generated_text' of the first returned sequence.
            verifications[key].append(out[0]['generated_text'])

            pbar.update()

    # Release cached GPU memory before moving on to the next label.
    torch.cuda.empty_cache()

# out = judge_pipe([coqa_100_generations['bfloat16_verification_prompt'][3]] * 25, **verification_pipe_kwargs)
# # out = judge_pipe(coqa_100_generations['bfloat16_verification_prompt'][0], **verification_pipe_kwargs)

# torch.cuda.empty_cache()
# out

In [None]:
# Rebuild the row layout that was fed to the judge so the verdict columns line up.
coqa_100_verifications_25 = coqa_100_generations.repeat(VERIFICATION_SAMPLES)

for key, value in verifications.items():
    name = f'{key}_verification_response'
    coqa_100_verifications_25 = coqa_100_verifications_25.add_column(name, value)

# Persist both datasets so the analysis cells can run without regenerating.
coqa_100_generations.save_to_disk('coqa_100_slice')
coqa_100_verifications_25.save_to_disk('coqa_100_verifications_25_slice')

In [None]:
# Reload the judged dataset from the directory written by save_to_disk.
coqa_100_verifications_25 = load_from_disk('coqa_100_verifications_25_slice')

In [None]:
# Same values as assigned in the generation and verification cells; repeated here,
# presumably so the analysis can start from the saved dataset -- keep them in sync.
VERIFICATION_SAMPLES = 25
SAMPLING = 50

In [None]:
for precision in ('bfloat16', 'int8', 'int4'):
    response_column = f'{precision}_verification_response'

    # 1 when the judge's last message contains "true" (case-insensitive), else 0.
    verdicts = [
        int('true' in chat[-1]['content'].lower())
        for chat in coqa_100_verifications_25[response_column]
    ]

    # The rows come from two nested repeat() calls, so the flat list is laid out as
    # (verification pass, generation sample, question); this assumes repeat()
    # appends whole copies of the dataset one after another.
    arr = np.array(verdicts).reshape(VERIFICATION_SAMPLES, SAMPLING, -1)

    # Average over the verification passes for every (sample, question) pair.
    mean_accuracy_per_sample = arr.mean(axis=0)
    # Mean and standard deviation over the generation samples, per question.
    accuracy_per_question = mean_accuracy_per_sample.mean(axis=0)
    std_per_question = mean_accuracy_per_sample.std(axis=0)

    # Collapse the per-question figures into single numbers.
    mean_accuracy = accuracy_per_question.mean(axis=0)
    avg_std = std_per_question.mean(axis=0)

    # Printed as percentages: mean accuracy, mean per-question std, then the
    # mean and std taken over every individual verdict.
    report = (
        (mean_accuracy * 100).round(3),
        (avg_std * 100).round(3),
        (arr * 100).mean().round(3),
        (arr * 100).std().round(3),
    )
    print(*report)

In [None]:
# First reference answer stored for row 3.
coqa_100_generations[3]['answers']['input_text'][0]

In [None]:
# bfloat16 model's generated chat for the same row.
coqa_100_generations[3]['bfloat16_response']

In [None]:
# Spot check: how often the judge accepts the first bfloat16 generation for one
# question, across all verification passes. The verdicts are read from the saved
# verification column, so the cell does not depend on an ad-hoc `out` variable
# from a manual judge run.
QUESTION_INDEX = 3

# Rows are laid out as (verification pass, generation sample, question); this
# assumes repeat() appends whole copies of the dataset one after another.
n_questions = len(coqa_100_verifications_25) // (VERIFICATION_SAMPLES * SAMPLING)
rows_per_pass = SAMPLING * n_questions
spot_check_rows = [
    verification_pass * rows_per_pass + QUESTION_INDEX
    for verification_pass in range(VERIFICATION_SAMPLES)
]
judged_chats = coqa_100_verifications_25.select(spot_check_rows)['bfloat16_verification_response']

# 1 when the judge's last message contains "true" (case-insensitive), else 0.
correct_counts = np.array([int('true' in chat[-1]['content'].lower()) for chat in judged_chats])
print(correct_counts.mean())
print(correct_counts.std())