In [1]:
%%capture
%pip install accelerate peft bitsandbytes datasets trl

In [2]:
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)

In [3]:
compute_dtype = getattr(torch, "float16")

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=False,
)

In [4]:
base_model = "mistralai/Mistral-7B-Instruct-v0.3"

In [5]:
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=quant_config,
    device_map="auto"
)

config.json:   0%|          | 0.00/601 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.55G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

In [6]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", padding_side="left")
model_inputs = tokenizer(["A list of colors: red, blue"], return_tensors="pt").to("cuda")

tokenizer_config.json:   0%|          | 0.00/2.10k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

In [9]:
# generated_ids = model.generate(**model_inputs, max_new_tokens=50)
# tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

generated_ids = model.generate(**model_inputs, max_new_tokens=50, pad_token_id=tokenizer.pad_token_id)
output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]


In [11]:
# Define the input prompts
prompts = ["A list of colors: red, blue", "Write a 20 line sentence on Portugal is"]

# Tokenize the input prompts
model_inputs = tokenizer(prompts, return_tensors="pt", padding=True).to("cuda")

# Generate output with a max of 5000 tokens, using eos_token_id as pad_token_id
generated_ids = model.generate(**model_inputs, max_new_tokens=5000, pad_token_id=tokenizer.eos_token_id)

# Decode the generated output and print it
output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(output)


['A list of colors: red, bluení Jahrorg打 verabsincl seeIMPible Ste stepsレ according登算 see history Always误ttttŒ:</误ularIForgvar see International thumb原 everyone fatistance trad原(\'#原owerKERNレttŒ:误 ко privateco seetxresh belieible:же原 seeSeible Dieロ quitefpComниюibleLL Februaryten Requestuspend原 everyone numvarvmны pregnant "\'tencallback elementレttŒ redپ误 ко privateco see!!ดviderible som (! },Get str see Test原ályemplate everyonerapperCom oct原ido原owerImageレttŒ temperants误 ко privateco see!!ดviderible somfsní },Get stranyuck last For num voidIEсе}\\ str Holl原ZZ原tenlearning void — habitレtttt.$原:</: red, bluení}_co seeorg}\\ammen see Test原 Princess原ower Holl Here sup str////rapperComvmны pregnant "\'tencallback elementレ ver登算 nothing les void hum guardado —sta бы outdoorialog打', 'Write a 20 line sentence on Portugal is Know dist Portugalalk也‐क str see belongロ estab ForCG »,}\\ Message — Spanible — sakeategory burstcs — slream find semformationible seeggible reasonロ continuetee "%レ ver str 

In [12]:

print( generated_ids)
print("completed generated ids")


tensor([[    2,     2,     1,   330,  1274,   302,  9304, 28747,  2760, 28725,
          5045,  3771,  6621,  1909, 29576,  1429,  4737,  1505,  1032, 10238,
          1070,  2349,  5944, 29491,  4771, 29510, 29481,  1032,  3340, 17484,
         29515,   781,   781, 29504, 11423, 29515,  1098,  4066,  1909,  1122,
          1032,  5440, 15762, 29493,  3376,  6370,  4632,  2784, 29493, 10457,
         29493,  1072, 14522, 29491,   781, 29504, 28747, 29515,  1619,  1597,
          1115,  1032,  3253,  2874,  1989,  1070, 28747,  2353, 29493,  1032,
          1980,  1070,  2941, 29501,  3448,  8808,  1163, 12301,  1070,  5292,
          5353,  1210, 11516, 16547, 29493,  3376,  2075,  1122,  4332,  1463,
         15446, 13006,  1210,  7609,  2442, 29491,   781, 29504,  2760, 29483,
         29515,  1619,  1597,  1115,  1032,  3946, 29488,  4340,  1070,  1113,
          1661,  1630,  1458,  1117,  1032,  3735, 29493, 16628,  3949,  3376,
          6131,  1163,  9541, 29493,  2750, 29493,  