In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
base_model_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
cache_dir = "/scratch/tathagato"


In [3]:
nf4_config = BitsAndBytesConfig(
    load_in_4bit= True,
    bnb_4bit_quant_type= "nf4",
    bnb_4bit_use_double_quant= True,
    bnb_4bit_compute_dtype=torch.bfloat16
)
model_kwargs = dict(
    use_cache=False,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map='cuda:0',
    cache_dir = cache_dir,
    attn_implementation = "eager",
    quantization_config = nf4_config, 
)

print("loading quantized model")
model = AutoModelForCausalLM.from_pretrained(base_model_path, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(base_model_path,cache_dir = cache_dir)

loading quantized model


In [20]:
messages = [
    {
        "role": "system",
        "content": "You are a friendly chatbot who always responds in the style of a pirate",
    },
    {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
 ]

text = tokenizer.apply_chat_template(messages, tokenize= False)
tokenizer.pad_token = tokenizer.unk_token
#add a function here

outputs = tokenizer(
                text,
                add_special_tokens=True,
                truncation=True,
                padding='max_length',
                max_length=2048,
                return_overflowing_tokens=False,
                return_length=False,
            )
print(len(outputs["input_ids"]))
print(tokenizer.decode(outputs["input_ids"], skip_special_tokens=True))
      

2048
<|system|>
You are a friendly chatbot who always responds in the style of a pirate 
<|user|>
How many helicopters can a human eat in one sitting? 



In [22]:
print(outputs.keys())

dict_keys(['input_ids', 'attention_mask'])


In [23]:
example = [1, 529, 29989, 5205, 29989, 29958, 13, 3492, 526, 263, 19780, 13563, 7451, 1058, 2337, 1371, 278, 1404, 2, 29871, 13, 29966, 29989, 1792, 29989, 29958, 13, 6113, 263, 15837, 310, 278, 1426, 29889, 450, 15837, 881, 367, 4226, 297, 3309, 636, 450, 1881, 1426, 338, 2183, 2400, 29871, 13, 29898, 29907, 10262, 29897, 3629, 261, 322, 14890, 907, 314, 29889, 739, 1838, 29915, 29873, 3721, 6709, 304, 3458, 746, 366, 1348, 310, 22037, 9687, 5101, 886, 1192, 2030, 7875, 1316, 408, 9922, 535, 322, 29808, 470, 1886, 557, 322, 7776, 824, 300, 29889, 1205, 21137, 29915, 29879, 1570, 9923, 1974, 5826, 556, 29891, 322, 278, 900, 2039, 472, 4111, 669, 23052, 29915, 29879, 526, 3815, 292, 701, 373, 263, 367, 261, 20603, 491, 14890, 907, 314, 1192, 15795, 287, 1559, 314, 295, 3347, 2786, 14890, 907, 314, 29892, 304, 367, 18378, 29889, 376, 4178, 445, 931, 306, 508, 9659, 393, 4111, 669, 23052, 29915, 29879, 322, 1570, 9923, 1974, 526, 11465, 1218, 304, 12020, 3773, 8326, 404, 2820, 5626, 591, 526, 15935, 403, 1048, 29892, 322, 393, 278, 2582, 674, 367, 628, 14803, 1699, 1570, 9923, 1974, 29915, 29879, 15944, 310, 317, 504, 475, 3097, 29892, 23774, 1798, 7214, 29892, 1497, 297, 263, 3229, 29889, 9134, 14582, 505, 263, 4955, 310, 5264, 5039, 1608, 29892, 322, 278, 716, 2060, 674, 367, 694, 1422, 29892, 896, 1827, 29889, 11275, 6507, 1838, 29915, 29873, 1827, 825, 278, 11531, 674, 367, 599, 1048, 29892, 541, 4111, 669, 23052, 29915, 29879, 24260, 12002, 4485, 15133, 15629, 19556, 10837, 2330, 27584, 372, 674, 367, 376, 6574, 627, 1319, 1213, 376, 4806, 29915, 276, 4802, 24909, 310, 1570, 9923, 1974, 5826, 556, 29891, 29892, 1009, 1819, 29892, 322, 1009, 2090, 9257, 29892, 322, 310, 3236, 1009, 367, 261, 1699, 540, 1497, 29889, 376, 4806, 29915, 276, 24173, 363, 278, 11531, 591, 29915, 345, 8906, 4208, 1213, 450, 14582, 674, 7475, 346, 278, 4902, 2678, 445, 1629, 29892, 322, 278, 367, 261, 338, 731, 304, 7124, 528, 295, 1960, 297, 278, 6416, 29889, 1570, 9923, 1974, 322, 4111, 669, 23052, 29915, 29879, 526, 1716, 825, 526, 2000, 376, 29933, 12767, 800, 1699, 263, 2284, 2450, 16610, 491, 278, 2024, 1661, 29899, 771, 9202, 350, 365, 6897, 304, 14582, 393, 5870, 967, 5264, 29892, 29380, 29892, 3633, 3097, 322, 1301, 862, 3819, 20801, 29889, 1570, 9923, 1974, 11286, 15075, 475, 519, 18032, 545, 29892, 23622, 1735, 322, 916, 14511, 5056, 29892, 1550, 4111, 669, 23052, 29915, 29879, 1192, 1286, 263, 11684, 8819, 653, 310, 5534, 378, 3820, 12392, 403, 853, 488, 369, 1192, 27111, 1565, 304, 967, 7251, 407, 347, 16778, 411, 2304, 363, 29380, 14511, 5056, 29892, 6534, 11302, 14231, 29892, 13718, 17193, 322, 901, 29889, 9208, 4098, 29892, 4111, 669, 23052, 29915, 29879, 274, 974, 618, 261, 4111, 315, 14899, 1497, 540, 29915, 29881, 367, 1722, 304, 278, 2969, 310, 263, 1766, 26323, 1648, 29899, 7192, 3880, 14890, 907, 314, 1047, 287, 388, 29892, 9763, 393, 731, 3104, 24909, 633, 433, 911, 29889, 1205, 14610, 368, 363, 367, 261, 24909, 29892, 727, 29915, 29879, 694, 5193, 310, 263, 367, 261, 29899, 29888, 4112, 4395, 14890, 907, 314, 29889, 2216, 3447, 8763, 29889, 2, 29871, 13, 29966, 29989, 465, 22137, 29989, 29958, 13, 8897, 4098, 29892, 4111, 669, 23052, 29915, 29879, 9763, 731, 1766, 26323, 1648, 24909, 633, 433, 911, 29892, 274, 974, 618, 261, 4111, 315, 14899, 13384, 393, 1047, 287, 388, 4111, 669, 23052, 29915, 29879, 1033, 6826, 263, 1766, 26323, 1648, 29899, 7192, 3880, 14890, 907, 314, 29892, 7952, 292, 1565, 304, 967, 7251, 407, 347, 16778, 29889, 2, 29871, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
print(tokenizer.decode(example, skip_special_tokens=True))

<|system|>
You are a friendly chatbot who always help the user 
<|user|>
Write a summary of the text. The summary should be normal in length.. The input text is given below 
(CNN)Beer and ice cream. It doesn't exactly spring to mind when you think of classic food pairings -- old friends such as bacon and eggs or steak and cabernet. But Colorado's New Belgium Brewery and the folks at Ben & Jerry's are teaming up on a beer inspired by ice cream -- salted caramel brownie ice cream, to be precise. "At this time I can confirm that Ben & Jerry's and New Belgium are collaborating to raise awareness around issues we are passionate about, and that the results will be delicious," New Belgium's Director of Sustainability, Jenn Vervier, said in a statement. Both companies have a history of social activism, and the new project will be no different, they say. Their release doesn't say what the campaign will be all about, but Ben & Jerry's Senior Global Marketing Manager Jay Curley promises it will b

: 

In [13]:
#print the special tokens
print(tokenizer.special_tokens_map)
#print all of the token id for the special tokens
print(tokenizer.all_special_ids)


{'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '</s>'}
[1, 2, 0]


In [5]:
messages = [
    {
        "role": "system",
        "content": "You are a friendly chatbot who always responds in the style of a pirate",
    },
    {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
 ]

text = tokenizer.apply_chat_template(messages, tokenize= False)
print(text)

tokens = tokenizer(text, return_tensors="pt")
print(tokens['input_ids'][0])

#decode with tokenizer with special tokens intact
decoded = tokenizer.decode(tokens["input_ids"][0], skip_special_tokens=False)

for token in decoded.split():
    print(token, tokenizer.convert_tokens_to_ids(token))
#print special tokens
print(tokenizer.special_tokens_map)
print(tokenizer.decode([13], skip_special_tokens=False))

<|system|>
You are a friendly chatbot who always responds in the style of a pirate</s>
<|user|>
How many helicopters can a human eat in one sitting?</s>

tensor([    1,   529, 29989,  5205, 29989, 29958,    13,  3492,   526,   263,
        19780, 13563,  7451,  1058,  2337, 10049, 29879,   297,   278,  3114,
          310,   263, 21625,   403,     2, 29871,    13, 29966, 29989,  1792,
        29989, 29958,    13,  5328,  1784,  1081,   293,   459,  2153,   508,
          263,  5199, 17545,   297,   697, 16246, 29973,     2, 29871,    13])
<s> 1
<|system|> 0
You 3492
are 598
a 29874
friendly 0
chatbot 0
who 15970
always 21936
responds 0
in 262
the 1552
style 3293
of 974
a 29874
pirate</s> 0
<|user|> 0
How 5328
many 13011
helicopters 0
can 3068
a 29874
human 26029
eat 0
in 262
one 650
sitting?</s> 0
{'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '</s>'}




In [6]:
#get the token to txt mapping from the tokenizer
print(tokenizer.convert_ids_to_tokens(tokens["input_ids"][0]))
token_token_id_pair = [(tokenizer.convert_ids_to_tokens([tokens["input_ids"][0][i]]), tokens["input_ids"][0][i]) for i in range(len(tokens["input_ids"][0]))]
print(token_token_id_pair)

['<s>', '▁<', '|', 'system', '|', '>', '<0x0A>', 'You', '▁are', '▁a', '▁friendly', '▁chat', 'bot', '▁who', '▁always', '▁respond', 's', '▁in', '▁the', '▁style', '▁of', '▁a', '▁pir', 'ate', '</s>', '▁', '<0x0A>', '<', '|', 'user', '|', '>', '<0x0A>', 'How', '▁many', '▁hel', 'ic', 'op', 'ters', '▁can', '▁a', '▁human', '▁eat', '▁in', '▁one', '▁sitting', '?', '</s>', '▁', '<0x0A>']
[(['<s>'], tensor(1)), (['▁<'], tensor(529)), (['|'], tensor(29989)), (['system'], tensor(5205)), (['|'], tensor(29989)), (['>'], tensor(29958)), (['<0x0A>'], tensor(13)), (['You'], tensor(3492)), (['▁are'], tensor(526)), (['▁a'], tensor(263)), (['▁friendly'], tensor(19780)), (['▁chat'], tensor(13563)), (['bot'], tensor(7451)), (['▁who'], tensor(1058)), (['▁always'], tensor(2337)), (['▁respond'], tensor(10049)), (['s'], tensor(29879)), (['▁in'], tensor(297)), (['▁the'], tensor(278)), (['▁style'], tensor(3114)), (['▁of'], tensor(310)), (['▁a'], tensor(263)), (['▁pir'], tensor(21625)), (['ate'], tensor(403)), (['</

In [7]:
print(tokens)

{'input_ids': tensor([[    1,   529, 29989,  5205, 29989, 29958,    13,  3492,   526,   263,
         19780, 13563,  7451,  1058,  2337, 10049, 29879,   297,   278,  3114,
           310,   263, 21625,   403,     2, 29871,    13, 29966, 29989,  1792,
         29989, 29958,    13,  5328,  1784,  1081,   293,   459,  2153,   508,
           263,  5199, 17545,   297,   697, 16246, 29973,     2, 29871,    13]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1]])}


In [9]:
#create a batched input where the first 10 words of the token is copied and batched
input_id_1 = tokens["input_ids"][0]
input_id_2 = tokens["input_ids"][0][:10]
input_ids = torch.stack([input_id_1, input_id_2])
print(input_ids)

RuntimeError: stack expects each tensor to be equal size, but got [50] at entry 0 and [10] at entry 1

In [8]:
output = model(**tokens)

In [18]:
print(output.keys())

odict_keys(['logits'])


In [20]:
print(output['logits'].shape)

torch.Size([1, 50, 32000])


In [21]:
#decode the output
decoded_output = tokenizer.decode(output['logits'][0].argmax(dim=-1))

In [22]:
print(decoded_output)

MIT|system|>
</s> can a successful andbot that can gres to a same of a humanate. 
<|user|>
Can can pirmopters does the pir carry in one sitting?</s> 
<
