<a href="https://colab.research.google.com/github/sdossou/LoRA_QLoRA/blob/main/%22Instruction_Tuning%22_Mistral_7B_Instruct_with_LoRA_vDolly.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Instruct-tuning Mistral-7B-Instruct-v0.2 using `peft`, `transformers` and `bitsandbytes`

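This notebook fine-tunes Mistral-7B-Instruct-v0.2, a large language model.

Rather than retraining all of the model's weights, it uses LoRA (Low-Rank Adaptation): the base weights stay frozen and only small low-rank adapter matrices are trained.

It uses the `databricks/databricks-dolly-15k` dataset to fine-tune Mistral-7B-Instruct-v0.2 to generate the instruction that could have produced a given response.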

### Dependencies

Installing relevant dependencies.

In [None]:
!pip install -qU bitsandbytes datasets accelerate loralib transformers peft

### Model loading
Loading the `mistralai/Mistral-7B-Instruct-v0.2` model and tokenizer.




In [None]:
import torch
torch.cuda.is_available()

True

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig

model_id = "mistralai/Mistral-7B-Instruct-v0.2"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    use_cache=False,
    device_map="auto"
)

Loading the tokenizer with the padding token.

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

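
As a rough illustration of what these two tokenizer settings imply, the sketch below right-pads a batch of token-id lists with the EOS id and builds the matching attention mask. It is plain Python, not the `transformers` implementation; `pad_right` and the example ids are hypothetical, and the EOS id of 2 is taken from the generation warning later in this notebook.

```python
EOS_ID = 2  # eos_token_id reported in the generation warning later in this notebook

def pad_right(batch, pad_id=EOS_ID):
    """Right-pad lists of token ids to a common length and build attention masks."""
    max_len = max(len(ids) for ids in batch)
    input_ids, attention_mask = [], []
    for ids in batch:
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)              # padding goes on the right
        attention_mask.append([1] * len(ids) + [0] * n_pad)   # 0 marks padded positions
    return input_ids, attention_mask

# Hypothetical token ids, for illustration only.
ids, mask = pad_right([[1, 733, 16289], [1, 733]])
print(ids)   # [[1, 733, 16289], [1, 733, 2]]
print(mask)  # [[1, 1, 1], [1, 1, 0]]
```

Because the pad id equals the EOS id here, the attention mask is the only thing that distinguishes padding from a genuine end-of-sequence token.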

#### Model Architecture

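LoRA is usually applied to the modules that hold the attention weights. Which module names exist is model dependent: the printout below shows that Mistral exposes `q_proj`, `k_proj`, `v_proj` and `o_proj` (there is no fused `query_key_value` module, which appears in some other architectures).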

In [None]:
print(base_model)

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
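
The shapes printed above are enough to count the model's weights. The sketch below does that arithmetic in plain Python. It assumes the truncated tail of the printout contains a final norm of width 4096 and an untied `lm_head` of shape 4096 x 32000 (neither is visible in the output shown here), so treat the total as an estimate.

```python
# Shapes copied from the printed architecture: (in_features, out_features).
HIDDEN, VOCAB, N_LAYERS = 4096, 32000, 32
LINEAR_SHAPES = {
    "q_proj": (4096, 4096),
    "k_proj": (4096, 1024),
    "v_proj": (4096, 1024),
    "o_proj": (4096, 4096),
    "gate_proj": (4096, 14336),
    "up_proj": (4096, 14336),
    "down_proj": (14336, 4096),
}

linear_per_layer = sum(i * o for i, o in LINEAR_SHAPES.values())
norms_per_layer = 2 * HIDDEN  # input_layernorm + post_attention_layernorm
embed = VOCAB * HIDDEN

# Assumption: a final RMSNorm and an untied lm_head follow the truncated printout.
final_norm = HIDDEN
lm_head = HIDDEN * VOCAB

total = embed + N_LAYERS * (linear_per_layer + norms_per_layer) + final_norm + lm_head
print(f"{total:,} parameters")
print(f"~{total * 2 / 1e9:.1f} GB if every weight were stored in 16 bits")
print(f"~{total * 0.5 / 1e9:.1f} GB if every weight were stored in 4 bits")
```

The 4-bit figure is a lower bound: it ignores quantization constants, and the embeddings and norms are not quantized in the printout above.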

### Apply LoRA

Wrapping the base model in a `PeftModel` with low-rank adapters (LoRA), using the `get_peft_model` utility function from `peft`.

#### Helper Function to Print Parameter Percentage

This helper prints how much LoRA reduces the number of trainable parameters.

In [None]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

#### Initialising LoRA Config
Main parameters:

- `r`: is the "rank" of the two decomposed matrices used to represent the weight matrix. This is the dimension of the decomposed matrices.

The following is an excerpt from the paper to help provide context for the selected `r`

- `target_modules`: As LoRA can be applied to *any* weight matrix, *which* modules (weight matrices) it is applied to needs to be configured. The LoRA paper suggests applying it only to the attention weights, while the QLoRA paper recommends all linear layers. In this notebook `target_modules` is left commented out, so `peft` falls back to its default for the architecture; the adapted model printed below shows LoRA on `q_proj` but not on `k_proj`.


- `task_type`: This is a derived property. If you're using a causal model, this should be set to `CAUSAL_LM`. Ensure this property is set based on the selected model.
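
For reference, the role of `r` can be written out. This is the usual LoRA formulation, with symbol names (`W_0`, `A`, `B`, `d_in`, `d_out`) chosen here for illustration rather than taken from this notebook:

```latex
h = W_0 x + \frac{\alpha}{r}\, B A x, \qquad
B \in \mathbb{R}^{d_{\text{out}} \times r}, \quad
A \in \mathbb{R}^{r \times d_{\text{in}}}
```

Only `A` and `B` are trained while `W_0` stays frozen, so each adapted layer adds `r * (d_in + d_out)` parameters instead of `d_in * d_out`. The scaling numerator corresponds to `lora_alpha`.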



In [None]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

lora_config = LoraConfig(
    r=64,
    lora_alpha=128,
    #target_modules=["q_proj", "v_proj", "k_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

base_model = prepare_model_for_kbit_training(base_model)
model = get_peft_model(base_model, lora_config)
print_trainable_parameters(model)

trainable params: 27262976 || all params: 3779334144 || trainable%: 0.7213698223345028
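
The two figures in this output can be reproduced by hand, which also reveals which modules received adapters. The sketch below is plain arithmetic on the shapes printed earlier. It assumes adapters on `q_proj` and `v_proj` only, that 4-bit layers report half their logical element count because two weights are packed per byte, and that the model ends with a final norm and an untied `lm_head` (not visible in the truncated printout).

```python
R, N_LAYERS, HIDDEN, VOCAB = 64, 32, 4096, 32000

def lora_params(d_in, d_out, r=R):
    """Parameters added by one adapter: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)

# Assumption: adapters on q_proj (4096 -> 4096) and v_proj (4096 -> 1024) only.
trainable = N_LAYERS * (lora_params(4096, 4096) + lora_params(4096, 1024))

# Logical weights in the seven Linear4bit layers of one decoder block.
linear_per_layer = 2 * 4096 * 4096 + 2 * 4096 * 1024 + 3 * 4096 * 14336
# Assumption: packed 4-bit storage reports half as many elements.
packed_linear = N_LAYERS * linear_per_layer // 2

# Unquantized parts: embeddings, lm_head (assumed untied), and RMSNorm weights.
unquantized = 2 * VOCAB * HIDDEN + N_LAYERS * 2 * HIDDEN + HIDDEN

all_params = packed_linear + unquantized + trainable
print(f"trainable params: {trainable} || all params: {all_params} "
      f"|| trainable%: {100 * trainable / all_params:.4f}")
```

Both values match the cell output, which suggests that with `target_modules` left unset the adapters were attached to `q_proj` and `v_proj` only. It also explains why "all params" is about 3.8B rather than about 7.2B.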


In [None]:
print(model)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): MistralForCausalLM(
      (model): MistralModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x MistralDecoderLayer(
            (self_attn): MistralSdpaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): Linear4bit(in_features=4096, out_features=1024

### Preprocessing

Loading the dolly dataset.

In [None]:
import transformers
from datasets import load_dataset

dataset_name = "databricks/databricks-dolly-15k"

In [None]:
dataset = load_dataset(dataset_name)
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['instruction', 'context', 'response', 'category'],
        num_rows: 15011
    })
})


In [None]:
print(dataset['train'][0])

{'instruction': 'When did Virgin Australia start operating?', 'context': "Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.", 'response': 'Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.', 'category': 'closed_qa'}


In [None]:
print(dataset['train'][5])

{'instruction': 'If I have more pieces at the time of stalemate, have I won?', 'context': 'Stalemate is a situation in chess where the player whose turn it is to move is not in check and has no legal move. Stalemate results in a draw. During the endgame, stalemate is a resource that can enable the player with the inferior position to draw the game rather than lose. In more complex positions, stalemate is much rarer, usually taking the form of a swindle that succeeds only if the superior side is inattentive.[citation needed] Stalemate is also a common theme in endgame studies and other chess problems.\n\nThe outcome of a stalemate was standardized as a draw in the 19th century. Before this standardization, its treatment varied widely, including being deemed a win for the stalemating player, a half-win for that player, or a loss for that player; not being permitted; and resulting in the stalemated player missing a turn. Stalemate rules vary in other games of the chess family.', 'response

Limiting the number of samples to 5K.

In [None]:
dataset_subset = dataset["train"].select(range(5_000))

Putting the data in the following form (a single line, exactly as `generate_prompt` emits it):

```
Generate a simple instruction that could result in the provided context.[INST]CONTEXT: Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.[/INST]INSTRUCTION: When did Virgin Australia start operating?
```

Showing the model examples of a prompt and its completion.

In [None]:
def generate_prompt(example, return_response=True) -> list:
    """Build the prompt for one example; returns a single-item list of strings."""
    full_prompt = "Generate a simple instruction that could result in the provided context."
    full_prompt += f"[INST]CONTEXT: {example['response']}[/INST]"

    if return_response:
        # Append the target completion only when building training examples.
        full_prompt += f"INSTRUCTION: {example['instruction']}"
    return [full_prompt]

In [None]:
generate_prompt(dataset_subset[0])[0]

'Generate a simple instruction that could result in the provided context.[INST]CONTEXT: Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.[/INST]INSTRUCTION: When did Virgin Australia start operating?'
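
One caveat: depending on the installed `trl` version, `SFTTrainer` may call `formatting_func` on a batch (a dict mapping each column name to a list of values) rather than on a single row, and expect one string per row back. That is an assumption about the library version, not something this notebook confirms. Under that assumption, a batch-aware variant could look like the sketch below; `generate_prompts_batched` is a hypothetical name and the toy batch is made up for illustration.

```python
def generate_prompts_batched(batch, return_response=True):
    """Build one prompt per row from a batch given as a dict of column lists."""
    prompts = []
    for instruction, response in zip(batch["instruction"], batch["response"]):
        prompt = "Generate a simple instruction that could result in the provided context."
        prompt += f"[INST]CONTEXT: {response}[/INST]"
        if return_response:
            prompt += f"INSTRUCTION: {instruction}"
        prompts.append(prompt)
    return prompts

# Toy batch in a column-oriented layout (illustrative data, not from the dataset).
batch = {
    "instruction": ["When did Virgin Australia start operating?", "What is stalemate?"],
    "response": ["It commenced services on 31 August 2000.", "A drawn position."],
}
prompts = generate_prompts_batched(batch)
print(len(prompts))
print(prompts[1])
```

With a per-row function applied to such a batch, `example['response']` would be a whole list, and the f-string would fold every row into one prompt.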

`TrainingArguments` exposes the same kinds of hyper-parameters as traditional ML training: number of epochs, batch size, learning rate, and so on.

If you run into CUDA out-of-memory errors, lower `per_device_train_batch_size` and reduce `r` in the `LoraConfig`. You will need to restart and re-run the notebook after doing so.

In [None]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="mistral-7b-instruct",
    num_train_epochs=100,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    optim="paged_adamw_32bit", # from the QLoRA paper
    logging_steps=1,
    save_strategy="epoch",
    learning_rate=2e-4,
    bf16=True, # ensure proper upcasting for compute dtypes
    tf32=True,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    disable_tqdm=True
)
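
These settings determine how many optimizer updates one epoch takes. The arithmetic below is a sketch that approximates the trainer's bookkeeping; it assumes a single GPU and that all 5,000 selected examples reach the trainer as separate sequences. The helper name is hypothetical.

```python
import math

def steps_per_epoch(num_examples, per_device_batch_size, grad_accum_steps, num_devices=1):
    """Approximate optimizer updates per epoch and the effective batch size."""
    effective_batch = per_device_batch_size * grad_accum_steps * num_devices
    micro_batches = math.ceil(num_examples / (per_device_batch_size * num_devices))
    # One optimizer update per `grad_accum_steps` micro-batches, at least one per epoch.
    return max(micro_batches // grad_accum_steps, 1), effective_batch

steps, effective_batch = steps_per_epoch(5_000, per_device_batch_size=4, grad_accum_steps=2)
print(f"effective batch size: {effective_batch}")
print(f"optimizer steps per epoch: {steps}")
print(f"optimizer steps for 100 epochs: {steps * 100}")
```

With `logging_steps=1`, a run over the full subset would therefore be expected to log many lines per epoch rather than one.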

In [None]:
!pip install trl -U -q

In [None]:
from trl import SFTTrainer

max_seq_length = 2048

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset_subset,
    peft_config=lora_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    formatting_func=generate_prompt,
    args=training_args,
)



In [None]:
trainer.train()



{'loss': 2.2288, 'grad_norm': 5.4804768562316895, 'learning_rate': 0.0002, 'epoch': 1.0}




{'loss': 1.8029, 'grad_norm': 1.2459341287612915, 'learning_rate': 0.0002, 'epoch': 2.0}




{'loss': 1.6533, 'grad_norm': 0.9689121246337891, 'learning_rate': 0.0002, 'epoch': 3.0}




{'loss': 1.534, 'grad_norm': 0.9293830394744873, 'learning_rate': 0.0002, 'epoch': 4.0}




{'loss': 1.6108, 'grad_norm': 1.0113400220870972, 'learning_rate': 0.0002, 'epoch': 5.0}




{'loss': 1.2895, 'grad_norm': 1.0044028759002686, 'learning_rate': 0.0002, 'epoch': 6.0}




{'loss': 1.3485, 'grad_norm': 1.1140828132629395, 'learning_rate': 0.0002, 'epoch': 7.0}




{'loss': 1.2836, 'grad_norm': 1.8770307302474976, 'learning_rate': 0.0002, 'epoch': 8.0}




{'loss': 1.1951, 'grad_norm': 1.8426021337509155, 'learning_rate': 0.0002, 'epoch': 9.0}




{'loss': 0.964, 'grad_norm': 1.9839824438095093, 'learning_rate': 0.0002, 'epoch': 10.0}




{'loss': 0.908, 'grad_norm': 2.443695306777954, 'learning_rate': 0.0002, 'epoch': 11.0}




{'loss': 0.6554, 'grad_norm': 1.8964112997055054, 'learning_rate': 0.0002, 'epoch': 12.0}




{'loss': 0.6601, 'grad_norm': 2.6477794647216797, 'learning_rate': 0.0002, 'epoch': 13.0}




{'loss': 0.4695, 'grad_norm': 2.779981851577759, 'learning_rate': 0.0002, 'epoch': 14.0}




{'loss': 0.3842, 'grad_norm': 1.6091575622558594, 'learning_rate': 0.0002, 'epoch': 15.0}




{'loss': 0.3647, 'grad_norm': 1.6802695989608765, 'learning_rate': 0.0002, 'epoch': 16.0}




{'loss': 0.3095, 'grad_norm': 4.428572654724121, 'learning_rate': 0.0002, 'epoch': 17.0}




{'loss': 0.3701, 'grad_norm': 2.102550506591797, 'learning_rate': 0.0002, 'epoch': 18.0}




{'loss': 0.1859, 'grad_norm': 1.4863488674163818, 'learning_rate': 0.0002, 'epoch': 19.0}




{'loss': 0.1385, 'grad_norm': 1.7573186159133911, 'learning_rate': 0.0002, 'epoch': 20.0}




{'loss': 0.1859, 'grad_norm': 2.493208885192871, 'learning_rate': 0.0002, 'epoch': 21.0}




{'loss': 0.1098, 'grad_norm': 1.2694973945617676, 'learning_rate': 0.0002, 'epoch': 22.0}




{'loss': 0.0755, 'grad_norm': 1.2623099088668823, 'learning_rate': 0.0002, 'epoch': 23.0}




{'loss': 0.0629, 'grad_norm': 0.8111180067062378, 'learning_rate': 0.0002, 'epoch': 24.0}




{'loss': 0.0524, 'grad_norm': 0.9838863611221313, 'learning_rate': 0.0002, 'epoch': 25.0}




{'loss': 0.0534, 'grad_norm': 0.8431493043899536, 'learning_rate': 0.0002, 'epoch': 26.0}




{'loss': 0.0301, 'grad_norm': 1.1564695835113525, 'learning_rate': 0.0002, 'epoch': 27.0}




{'loss': 0.044, 'grad_norm': 0.9574570059776306, 'learning_rate': 0.0002, 'epoch': 28.0}




{'loss': 0.0236, 'grad_norm': 0.9502139687538147, 'learning_rate': 0.0002, 'epoch': 29.0}




{'loss': 0.0185, 'grad_norm': 0.6847431659698486, 'learning_rate': 0.0002, 'epoch': 30.0}




{'loss': 0.0181, 'grad_norm': 0.5686737895011902, 'learning_rate': 0.0002, 'epoch': 31.0}




{'loss': 0.0139, 'grad_norm': 0.41077855229377747, 'learning_rate': 0.0002, 'epoch': 32.0}




{'loss': 0.0152, 'grad_norm': 1.0139888525009155, 'learning_rate': 0.0002, 'epoch': 33.0}




{'loss': 0.0127, 'grad_norm': 0.7188061475753784, 'learning_rate': 0.0002, 'epoch': 34.0}




{'loss': 0.0143, 'grad_norm': 0.7396520376205444, 'learning_rate': 0.0002, 'epoch': 35.0}




{'loss': 0.0118, 'grad_norm': 0.3488834500312805, 'learning_rate': 0.0002, 'epoch': 36.0}




{'loss': 0.0104, 'grad_norm': 0.32549235224723816, 'learning_rate': 0.0002, 'epoch': 37.0}




{'loss': 0.011, 'grad_norm': 1.486786127090454, 'learning_rate': 0.0002, 'epoch': 38.0}




{'loss': 0.0149, 'grad_norm': 1.3907114267349243, 'learning_rate': 0.0002, 'epoch': 39.0}




{'loss': 0.0152, 'grad_norm': 1.0777249336242676, 'learning_rate': 0.0002, 'epoch': 40.0}




{'loss': 0.0112, 'grad_norm': 0.43155691027641296, 'learning_rate': 0.0002, 'epoch': 41.0}




{'loss': 0.0098, 'grad_norm': 0.43113428354263306, 'learning_rate': 0.0002, 'epoch': 42.0}




{'loss': 0.0094, 'grad_norm': 0.4887959659099579, 'learning_rate': 0.0002, 'epoch': 43.0}




{'loss': 0.0102, 'grad_norm': 0.49672093987464905, 'learning_rate': 0.0002, 'epoch': 44.0}




{'loss': 0.0095, 'grad_norm': 0.8068413734436035, 'learning_rate': 0.0002, 'epoch': 45.0}




{'loss': 0.0101, 'grad_norm': 0.4478374123573303, 'learning_rate': 0.0002, 'epoch': 46.0}




{'loss': 0.009, 'grad_norm': 0.45143476128578186, 'learning_rate': 0.0002, 'epoch': 47.0}




{'loss': 0.0078, 'grad_norm': 0.2782938778400421, 'learning_rate': 0.0002, 'epoch': 48.0}




{'loss': 0.0078, 'grad_norm': 0.23565144836902618, 'learning_rate': 0.0002, 'epoch': 49.0}




{'loss': 0.0081, 'grad_norm': 0.3519022762775421, 'learning_rate': 0.0002, 'epoch': 50.0}




{'loss': 0.0078, 'grad_norm': 0.46870601177215576, 'learning_rate': 0.0002, 'epoch': 51.0}




{'loss': 0.0081, 'grad_norm': 0.19682571291923523, 'learning_rate': 0.0002, 'epoch': 52.0}




{'loss': 0.0077, 'grad_norm': 0.5014159679412842, 'learning_rate': 0.0002, 'epoch': 53.0}




{'loss': 0.0085, 'grad_norm': 0.6303669214248657, 'learning_rate': 0.0002, 'epoch': 54.0}




{'loss': 0.0115, 'grad_norm': 0.8170263767242432, 'learning_rate': 0.0002, 'epoch': 55.0}




{'loss': 0.009, 'grad_norm': 0.4157545268535614, 'learning_rate': 0.0002, 'epoch': 56.0}




{'loss': 0.0082, 'grad_norm': 0.49302032589912415, 'learning_rate': 0.0002, 'epoch': 57.0}




{'loss': 0.0077, 'grad_norm': 0.1642012894153595, 'learning_rate': 0.0002, 'epoch': 58.0}




{'loss': 0.0079, 'grad_norm': 0.3256623148918152, 'learning_rate': 0.0002, 'epoch': 59.0}




{'loss': 0.0077, 'grad_norm': 0.2503563463687897, 'learning_rate': 0.0002, 'epoch': 60.0}




{'loss': 0.0071, 'grad_norm': 0.15481582283973694, 'learning_rate': 0.0002, 'epoch': 61.0}




{'loss': 0.007, 'grad_norm': 0.24927616119384766, 'learning_rate': 0.0002, 'epoch': 62.0}




{'loss': 0.0069, 'grad_norm': 0.19936080276966095, 'learning_rate': 0.0002, 'epoch': 63.0}




{'loss': 0.0072, 'grad_norm': 0.27008363604545593, 'learning_rate': 0.0002, 'epoch': 64.0}




{'loss': 0.008, 'grad_norm': 0.43821918964385986, 'learning_rate': 0.0002, 'epoch': 65.0}




{'loss': 0.0071, 'grad_norm': 0.5714155435562134, 'learning_rate': 0.0002, 'epoch': 66.0}




{'loss': 0.0079, 'grad_norm': 0.3054104447364807, 'learning_rate': 0.0002, 'epoch': 67.0}




{'loss': 0.0078, 'grad_norm': 0.2790951430797577, 'learning_rate': 0.0002, 'epoch': 68.0}




{'loss': 0.0068, 'grad_norm': 0.17515155673027039, 'learning_rate': 0.0002, 'epoch': 69.0}




{'loss': 0.0074, 'grad_norm': 0.3287735879421234, 'learning_rate': 0.0002, 'epoch': 70.0}




{'loss': 0.0072, 'grad_norm': 0.30873391032218933, 'learning_rate': 0.0002, 'epoch': 71.0}




{'loss': 0.0074, 'grad_norm': 0.39411771297454834, 'learning_rate': 0.0002, 'epoch': 72.0}




{'loss': 0.0068, 'grad_norm': 0.12658937275409698, 'learning_rate': 0.0002, 'epoch': 73.0}




{'loss': 0.0068, 'grad_norm': 0.16186745464801788, 'learning_rate': 0.0002, 'epoch': 74.0}




{'loss': 0.0069, 'grad_norm': 0.25508764386177063, 'learning_rate': 0.0002, 'epoch': 75.0}




{'loss': 0.0074, 'grad_norm': 0.6734135746955872, 'learning_rate': 0.0002, 'epoch': 76.0}




{'loss': 0.0092, 'grad_norm': 0.5721321105957031, 'learning_rate': 0.0002, 'epoch': 77.0}




{'loss': 0.0077, 'grad_norm': 0.5033586621284485, 'learning_rate': 0.0002, 'epoch': 78.0}




{'loss': 0.007, 'grad_norm': 0.30522647500038147, 'learning_rate': 0.0002, 'epoch': 79.0}




{'loss': 0.0066, 'grad_norm': 0.27377960085868835, 'learning_rate': 0.0002, 'epoch': 80.0}




{'loss': 0.0063, 'grad_norm': 0.1430925726890564, 'learning_rate': 0.0002, 'epoch': 81.0}




{'loss': 0.0065, 'grad_norm': 0.14892429113388062, 'learning_rate': 0.0002, 'epoch': 82.0}




{'loss': 0.0074, 'grad_norm': 0.6323414444923401, 'learning_rate': 0.0002, 'epoch': 83.0}




{'loss': 0.007, 'grad_norm': 0.26101580262184143, 'learning_rate': 0.0002, 'epoch': 84.0}




{'loss': 0.0062, 'grad_norm': 0.11067664623260498, 'learning_rate': 0.0002, 'epoch': 85.0}




{'loss': 0.0125, 'grad_norm': 0.5918269157409668, 'learning_rate': 0.0002, 'epoch': 86.0}




{'loss': 0.0071, 'grad_norm': 0.2367808222770691, 'learning_rate': 0.0002, 'epoch': 87.0}




{'loss': 0.0063, 'grad_norm': 0.25942203402519226, 'learning_rate': 0.0002, 'epoch': 88.0}




{'loss': 0.006, 'grad_norm': 0.10238330811262131, 'learning_rate': 0.0002, 'epoch': 89.0}




{'loss': 0.0071, 'grad_norm': 0.42621949315071106, 'learning_rate': 0.0002, 'epoch': 90.0}




{'loss': 0.0074, 'grad_norm': 0.42982426285743713, 'learning_rate': 0.0002, 'epoch': 91.0}




{'loss': 0.0062, 'grad_norm': 0.17316237092018127, 'learning_rate': 0.0002, 'epoch': 92.0}




{'loss': 0.0094, 'grad_norm': 1.241540551185608, 'learning_rate': 0.0002, 'epoch': 93.0}




{'loss': 0.0061, 'grad_norm': 0.12055987119674683, 'learning_rate': 0.0002, 'epoch': 94.0}




{'loss': 0.0064, 'grad_norm': 0.17949679493904114, 'learning_rate': 0.0002, 'epoch': 95.0}




{'loss': 0.0064, 'grad_norm': 0.37709370255470276, 'learning_rate': 0.0002, 'epoch': 96.0}




{'loss': 0.0062, 'grad_norm': 0.30560192465782166, 'learning_rate': 0.0002, 'epoch': 97.0}




{'loss': 0.006, 'grad_norm': 0.07423694431781769, 'learning_rate': 0.0002, 'epoch': 98.0}




{'loss': 0.0059, 'grad_norm': 0.07556795328855515, 'learning_rate': 0.0002, 'epoch': 99.0}




{'loss': 0.0061, 'grad_norm': 0.12097375839948654, 'learning_rate': 0.0002, 'epoch': 100.0}
{'train_runtime': 484.3011, 'train_samples_per_second': 1.032, 'train_steps_per_second': 0.206, 'train_loss': 0.2061143883317709, 'epoch': 100.0}


TrainOutput(global_step=100, training_loss=0.2061143883317709, metrics={'train_runtime': 484.3011, 'train_samples_per_second': 1.032, 'train_steps_per_second': 0.206, 'train_loss': 0.2061143883317709, 'epoch': 100.0})
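
The summary metrics allow a quick sanity check on how much data each epoch actually covered. The sketch below only rearranges the logged numbers; it assumes `train_samples_per_second` is computed as total samples processed divided by `train_runtime`.

```python
# Values copied from the TrainOutput above.
train_runtime = 484.3011
train_samples_per_second = 1.032
epochs = 100
global_step = 100

total_samples = train_runtime * train_samples_per_second
samples_per_epoch = total_samples / epochs
steps_per_epoch = global_step / epochs

print(f"samples processed in total: ~{total_samples:.0f}")
print(f"samples per epoch: ~{samples_per_epoch:.1f}")
print(f"optimizer steps per epoch: {steps_per_epoch:.0f}")
```

Roughly five sequences per epoch is far fewer than the 5,000 selected examples, so the near-zero final loss more likely reflects memorising a handful of sequences than learning the task. One possible cause, not confirmed here, is `formatting_func` being applied to whole batches of rows so that each batch collapses into a single prompt.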

In [None]:
trainer.save_model()

In [None]:
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained(
    training_args.output_dir,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    quantization_config=bnb_config,  # replaces the deprecated load_in_4bit=True
)
tokenizer = AutoTokenizer.from_pretrained(training_args.output_dir)

In [None]:
sample = dataset_subset[5]

prompt = generate_prompt(sample, return_response=False)

In [None]:
input_ids = tokenizer(prompt[0], return_tensors="pt", truncation=True).input_ids.cuda()

outputs = model.generate(input_ids=input_ids, max_new_tokens=100, do_sample=True, top_p=0.9, temperature=0.5)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


In [None]:
print(f"Prompt:\n{prompt[0]}\n")
print(f"-------------")
print(f"Generated instruction:\n{tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0][len(prompt[0]):]}")
print(f"-------------")
print(f"Ground truth:\n{sample['instruction']}")

Prompt:
Generate a simple instruction that could result in the provided context.[INST]CONTEXT: No. 
Stalemate is a drawn position. It doesn't matter who has captured more pieces or is in a winning position[/INST]

-------------
Generated instruction:
 Instruction: Declare the outcome when there are no legal moves left for either player.

Context: In a stalemate, the game ends as a draw. It does not matter who has captured more pieces or is in a winning position.
-------------
Ground truth:
If I have more pieces at the time of stalemate, have I won?


In [None]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

untuned_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    use_cache=False,
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained(model_id)

In [None]:
input_ids = tokenizer(prompt[0], return_tensors="pt", truncation=True).input_ids.cuda()

outputs = untuned_model.generate(input_ids=input_ids, max_new_tokens=100, do_sample=True, top_p=0.9, temperature=0.5)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


In [None]:
print(f"Prompt:\n{prompt}\n")
print(f"-------------")
print(f"Generated instruction:\n{tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0][len(prompt[0]):]}")
print(f"-------------")
print(f"Ground truth:\n{sample['instruction']}")

Prompt:
["Generate a simple instruction that could result in the provided context.[INST]CONTEXT: No. \nStalemate is a drawn position. It doesn't matter who has captured more pieces or is in a winning position[/INST]"]

-------------
Generated instruction:
 If the game has reached a point where neither player can make a legal move that would result in an advantage, then the game is in a stalemate position. In this case, the game is declared a draw.
-------------
Ground truth:
If I have more pieces at the time of stalemate, have I won?


This notebook is adapted from the notebook developed by AI Makerspace.