In [1]:
import json

import numpy as np
import torch
import torch.nn as nn
import bitsandbytes as bnb

import datasets
from datasets import load_dataset, load_from_disk
from datasets.arrow_dataset import Dataset
from evaluate import load

from transformers import pipeline
from transformers import BitsAndBytesConfig
from transformers import AutoModelForCausalLM
from transformers.pipelines.pt_utils import KeyDataset

import accelerate

from tqdm import tqdm
# from tqdm.notebook import tqdm

import matplotlib.pyplot as plt

In [2]:
model_id = "meta-llama/Llama-3.2-1B-Instruct"

pipe = pipeline('text-generation', model=model_id, torch_dtype=torch.bfloat16)
# LLama 3.2 has multiple eos_token_id. We use the "128001"
pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id[0]

_model = pipe.model
_tokenizer = pipe.tokenizer

Device set to use cuda:0


In [3]:
quant8_config = BitsAndBytesConfig(load_in_8bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
quant8_model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quant8_config)
quant8_pipe = pipeline('text-generation', model=quant8_model, tokenizer=_tokenizer, torch_dtype='auto')
# LLama 3.2 has multiple eos_token_id. We use the "128001"
quant8_pipe.tokenizer.pad_token_id = quant8_pipe.model.config.eos_token_id[0]

`low_cpu_mem_usage` was None, now default to True since model is quantized.
Device set to use cuda:0


In [4]:
quant4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type='nf4')
quant4_model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quant4_config)
quant4_pipe = pipeline('text-generation', model=quant4_model, tokenizer=_tokenizer, torch_dtype='auto')
# LLama 3.2 has multiple eos_token_id. We use the "128001"
quant4_pipe.tokenizer.pad_token_id = quant4_pipe.model.config.eos_token_id[0]

`low_cpu_mem_usage` was None, now default to True since model is quantized.
Device set to use cuda:0


In [5]:
llama_1b = sum(param.numel() * param.element_size() for param in _model.parameters()) / (1024 ** 2)
llama_1b_int8 = sum(param.numel() * param.element_size() for param in quant8_model.parameters()) /  (1024 ** 2)
llama_1b_int4 = sum(param.numel() * param.element_size() for param in quant4_model.parameters()) /  (1024 ** 2)

print(f'LLama 3.2 1B Instruct with bfloat16 uses {round(llama_1b, 2)}MB of memory')
print(f'LLama 3.2 1B Instruct with bfloat16 and int8 uses {round(llama_1b_int8, 2)}MB of memory')
print(f'LLama 3.2 1B Instruct with bfloat16 and int4 uses {round(llama_1b_int4, 2)}MB of memory')

LLama 3.2 1B Instruct with bfloat16 uses 2357.13MB of memory
LLama 3.2 1B Instruct with bfloat16 and int8 uses 1429.13MB of memory
LLama 3.2 1B Instruct with bfloat16 and int4 uses 965.13MB of memory


### CoQA Dataset without conversations

In [6]:
def get_prompt_for_coqa_question(context: str, question: str) -> list:
    """
    Get prompt for CoQA in chat format. This includes a system and an user prompt.
    """
    return [
        {'role': 'system', 'content': f'You are a chatbot which answers user question in extremely concise manner possible from given context, "{context}".'},
        {'role': 'user', 'content': question}
    ]

In [7]:
coqa_dataset = load_dataset('stanfordnlp/coqa', split='all')
coqa_100 = coqa_dataset.select(range(100))
coqa_100 = coqa_100.add_column('question_prompt', column=[get_prompt_for_coqa_question(context, questions[0]) 
                                                          for context, questions in zip(coqa_100['story'], coqa_100['questions'])])

coqa_100

Dataset({
    features: ['source', 'story', 'questions', 'answers', 'question_prompt'],
    num_rows: 100
})

In [8]:
coqa_100['question_prompt'][0]

[{'content': 'You are a chatbot which answers user question in extremely concise manner possible from given context, "The Vatican Apostolic Library (), more commonly called the Vatican Library or simply the Vat, is the library of the Holy See, located in Vatican City. Formally established in 1475, although it is much older, it is one of the oldest libraries in the world and contains one of the most significant collections of historical texts. It has 75,000 codices from throughout history, as well as 1.1 million printed books, which include some 8,500 incunabula. \n\nThe Vatican Library is a research library for history, law, philosophy, science and theology. The Vatican Library is open to anyone who can document their qualifications and research needs. Photocopies for private study of pages from books published between 1801 and 1990 can be requested in person or by mail. \n\nIn March 2014, the Vatican Library began an initial four-year project of digitising its collection of manuscript

In [9]:
# Constrastive search sampling gives only one sample per input, so we sample it multiple times.
# SAMPLING = 50
SAMPLING = 50

# Select as per the GPU?
MULTIPLIER = 30

# params for Constrastive Search Sampling.
pipe_kwargs = {
    'penalty_alpha': 0.6,
    'top_k': 40, # 'top_k': 4,
    'max_new_tokens': 128,
    'pad_token_id': _tokenizer.pad_token_id,
    'batch_size': 5 * MULTIPLIER
}

del MULTIPLIER

# Originally, it's right side, but huggingface throws warning. 
# "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer."
_tokenizer.padding_side = 'left'


generations = {
    'bfloat16': [],
    'int8': [],
    'int4': []
}


inference_ds = KeyDataset(coqa_100.repeat(SAMPLING), key='question_prompt')
n = len(inference_ds)

with tqdm(total=n, desc='bfloat16') as pbar:
    for out in pipe(inference_ds, **pipe_kwargs):
        generations['bfloat16'].append(out[0]['generated_text'])

        pbar.update()

with tqdm(total=n, desc='int8') as pbar:
    for out in quant8_pipe(inference_ds, **pipe_kwargs):
        generations['int8'].append(out[0]['generated_text'])

        pbar.update()

with tqdm(total=n, desc='int4') as pbar:
    for out in quant4_pipe(inference_ds, **pipe_kwargs):
        generations['int4'].append(out[0]['generated_text'])

        pbar.update()

torch.cuda.empty_cache()

bfloat16: 100%|██████████| 5000/5000 [07:19<00:00, 11.37it/s]  
int8: 100%|██████████| 5000/5000 [08:20<00:00, 10.00it/s]  
int4: 100%|██████████| 5000/5000 [07:19<00:00, 11.37it/s]  


In [10]:
def get_prompt_for_verification(questions: str, answers: dict, chat_history: list, context: str) -> list:
    """
    Get prompt in chat format. This includes a system and an user prompt.
    """
    question = questions[0]
    # print(answers.keys())
    # answer = list(answers.values())[0]
    answer = answers['input_text'][0]
    response = chat_history[-1]['content']
    
    return [
        {'role': 'system', 'content': 'For the following query is the response correct reply with True or False, nothing more. Look at the context or the answer'}, # , nothing more
        {'role': 'user', 'content': f'From the context "{context}" the question "{question}", has the correct answer as "{answer}". Does the response "{response}" fits the correct answer?'}
    ]

In [11]:
coqa_100_generations = coqa_100.repeat(SAMPLING)

for key, value in generations.items():
    name = f'{key}_response'
    coqa_100_generations = coqa_100_generations.add_column(name, value)
    coqa_100_generations = coqa_100_generations.add_column(f'{key}_verification_prompt', column=list(map(get_prompt_for_verification, 
                                                                                                       coqa_100_generations['questions'], 
                                                                                                       coqa_100_generations['answers'],
                                                                                                       coqa_100_generations[name],
                                                                                                       coqa_100_generations['story'])))

coqa_100_generations

Dataset({
    features: ['source', 'story', 'questions', 'answers', 'question_prompt', 'bfloat16_response', 'bfloat16_verification_prompt', 'int8_response', 'int8_verification_prompt', 'int4_response', 'int4_verification_prompt'],
    num_rows: 5000
})

## Verify using Judge model

In [12]:
judge_model_id = "meta-llama/Llama-3.2-3B-Instruct"

judge_pipe = pipeline('text-generation', model=judge_model_id, torch_dtype=torch.bfloat16)
# LLama 3.2 has multiple eos_token_id. We use the "128001"
judge_pipe.tokenizer.pad_token_id = judge_pipe.model.config.eos_token_id[0]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda:0


In [14]:
verification_pipe_kwargs = {
    'pad_token_id': judge_pipe.tokenizer.pad_token_id,
    'batch_size': 50
}

judge_pipe.tokenizer.padding_side = 'left'
VERIFICATION_SAMPLES = 25

verifications = {
    'bfloat16': [],
    'int8': [],
    'int4': []
}

for key in verifications.keys():
    verification_ds = KeyDataset(coqa_100_generations.repeat(VERIFICATION_SAMPLES), key=f'{key}_verification_prompt')
    n = len(verification_ds)

    with tqdm(total=n, desc=f'{key}_verification') as pbar:
        for out in judge_pipe(verification_ds, **verification_pipe_kwargs):
            verifications[key].append(out[0]['generated_text'])

            pbar.update()

    torch.cuda.empty_cache()

# out = judge_pipe([coqa_100_generations['bfloat16_verification_prompt'][3]] * 25, **verification_pipe_kwargs)
# # out = judge_pipe(coqa_100_generations['bfloat16_verification_prompt'][0], **verification_pipe_kwargs)

# torch.cuda.empty_cache()
# out

bfloat16_verification: 100%|██████████| 125000/125000 [1:38:03<00:00, 21.25it/s]    
int8_verification: 100%|██████████| 125000/125000 [1:39:25<00:00, 20.95it/s]    
int4_verification: 100%|██████████| 125000/125000 [1:40:24<00:00, 20.75it/s]    


In [15]:
coqa_100_verifications_25 = coqa_100_generations.repeat(VERIFICATION_SAMPLES)

for key, value in verifications.items():
    name = f'{key}_verification_response'
    coqa_100_verifications_25 = coqa_100_verifications_25.add_column(name, value)

coqa_100_generations.save_to_disk('coqa_100_slice')
coqa_100_verifications_25.save_to_disk('coqa_100_verifications_25_slice')

Saving the dataset (0/1 shards):   0%|          | 0/5000 [00:00<?, ? examples/s]

Saving the dataset (0/6 shards):   0%|          | 0/125000 [00:00<?, ? examples/s]

In [3]:
coqa_100_verifications_25 = load_from_disk('coqa_100_verifications_25_slice')

In [5]:
VERIFICATION_SAMPLES = 25
SAMPLING = 50

In [8]:
for verification_key in ('bfloat16', 'int8', 'int4'): 
    verification_key = f'{verification_key}_verification_response'
    arr = np.array([int('true' in chat[-1]['content'].lower()) for chat in coqa_100_verifications_25[verification_key]])
    
    # The first two dims are ordered following FIFO strategy to 
    # balance the 'tile' operations applied on the original 100 questions.
    arr = arr.reshape(VERIFICATION_SAMPLES, SAMPLING, -1)
    # We average the results over each verification for each sample.
    mean_accuracy_per_sample : np.ndarray = arr.mean(axis=0)
    # We then average and std over each sample for each question.
    accuracy_per_question : np.ndarray = mean_accuracy_per_sample.mean(axis=0)
    std_per_question : np.ndarray = mean_accuracy_per_sample.std(axis=0)

    mean_accuracy = accuracy_per_question.mean(axis=0)
    avg_std = std_per_question.mean(axis=0)

    print((mean_accuracy * 100).round(3), (avg_std * 100).round(3), (arr * 100).mean().round(3), (arr * 100).std().round(3))

23.401 13.341 23.401 42.338
24.228 13.397 24.228 42.846
22.047 13.489 22.047 41.456


In [14]:
coqa_100_generations[3]['answers']['input_text'][0]

'Donner'

In [15]:
coqa_100_generations[3]['bfloat16_response']

[{'content': 'You are a chatbot which answers user question in extremely concise manner possible from given context, "(CNN) -- The longest-running holiday special still has a very shiny nose. \n\n"Rudolph the Red-Nosed Reindeer" premiered on television December 6, 1964, and is now one of the holiday season\'s perennial favorites. The story of the reindeer who saves Christmas is beloved among children and adults alike. \n\nThe Rankin-Bass animated film production company used Japanese puppets and stop motion to tell the tale, bolstered by a soundtrack featuring Burl Ives\' rendition of the theme song. \n\nIn the story, Santa\'s reindeer Donner and his wife have a son, Rudolph, who has the distinction of a nose that glows. He runs away after being made to feel an outcast and links up with an elf who dreams of becoming a dentist and an adventurer seeking silver and gold. \n\nAfter ending up on the Island of Misfit Toys and wandering for a while, Rudolph goes on to save his loved ones from

In [16]:
out

[[{'generated_text': [{'content': 'For the following query is the response correct reply with True or False, nothing more. Look at the context or the answer',
     'role': 'system'},
    {'content': 'From the context "(CNN) -- The longest-running holiday special still has a very shiny nose. \n\n"Rudolph the Red-Nosed Reindeer" premiered on television December 6, 1964, and is now one of the holiday season\'s perennial favorites. The story of the reindeer who saves Christmas is beloved among children and adults alike. \n\nThe Rankin-Bass animated film production company used Japanese puppets and stop motion to tell the tale, bolstered by a soundtrack featuring Burl Ives\' rendition of the theme song. \n\nIn the story, Santa\'s reindeer Donner and his wife have a son, Rudolph, who has the distinction of a nose that glows. He runs away after being made to feel an outcast and links up with an elf who dreams of becoming a dentist and an adventurer seeking silver and gold. \n\nAfter ending up

In [17]:
correct_counts = np.array([int('true' in chat[0]['generated_text'][-1]['content'].lower()) for chat in out])
print(correct_counts.mean())
print(correct_counts.std())

1.0
0.0


In [183]:
[coqa_100_generations['bfloat16_verification_prompt'][0]] * 2

[[{'content': 'For the following query give response as True or False, nothing more.',
   'role': 'system'},
  {'content': 'From the context "The Vatican Apostolic Library (), more commonly called the Vatican Library or simply the Vat, is the library of the Holy See, located in Vatican City. Formally established in 1475, although it is much older, it is one of the oldest libraries in the world and contains one of the most significant collections of historical texts. It has 75,000 codices from throughout history, as well as 1.1 million printed books, which include some 8,500 incunabula. \n\nThe Vatican Library is a research library for history, law, philosophy, science and theology. The Vatican Library is open to anyone who can document their qualifications and research needs. Photocopies for private study of pages from books published between 1801 and 1990 can be requested in person or by mail. \n\nIn March 2014, the Vatican Library began an initial four-year project of digitising its 

In [171]:
out[0][0]['generated_text'][-1]['content']

'False'

In [None]:
accuracy = 

[[{'generated_text': [{'content': 'For the following query give response as True or False, nothing more.',
     'role': 'system'},
    {'content': 'From the context "The Vatican Apostolic Library (), more commonly called the Vatican Library or simply the Vat, is the library of the Holy See, located in Vatican City. Formally established in 1475, although it is much older, it is one of the oldest libraries in the world and contains one of the most significant collections of historical texts. It has 75,000 codices from throughout history, as well as 1.1 million printed books, which include some 8,500 incunabula. \n\nThe Vatican Library is a research library for history, law, philosophy, science and theology. The Vatican Library is open to anyone who can document their qualifications and research needs. Photocopies for private study of pages from books published between 1801 and 1990 can be requested in person or by mail. \n\nIn March 2014, the Vatican Library began an initial four-year pr

In [165]:
[coqa_100_generations['bfloat16_verification_prompt'][0]] * 2

[[{'content': 'For the following query give response as True or False, nothing more.',
   'role': 'system'},
  {'content': 'From the context "The Vatican Apostolic Library (), more commonly called the Vatican Library or simply the Vat, is the library of the Holy See, located in Vatican City. Formally established in 1475, although it is much older, it is one of the oldest libraries in the world and contains one of the most significant collections of historical texts. It has 75,000 codices from throughout history, as well as 1.1 million printed books, which include some 8,500 incunabula. \n\nThe Vatican Library is a research library for history, law, philosophy, science and theology. The Vatican Library is open to anyone who can document their qualifications and research needs. Photocopies for private study of pages from books published between 1801 and 1990 can be requested in person or by mail. \n\nIn March 2014, the Vatican Library began an initial four-year project of digitising its 