## Using Unsloth to finetune

In [1]:

import os 
import sys 

from datasets import Dataset

from dotenv import find_dotenv, load_dotenv

load_dotenv()



True

This setup follows Unsloth's example notebooks for finetuning.

In [2]:
max_seq_length = 16000 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
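As a rough check on why `load_in_4bit` matters on a ~12 GB card, the sketch below estimates the memory taken by the model weights alone at different bit widths. It assumes roughly 3.21 billion parameters for Llama-3.2-3B (an approximate figure, not taken from this notebook) and ignores activations, the KV cache, optimizer state, quantization metadata, and the fact that some layers (such as the embeddings) are usually left unquantized, so treat the results as lower bounds rather than measurements.

```python
# Rough weight-only memory estimate; NOT a measurement of real GPU usage.
# Assumption: Llama-3.2-3B has roughly 3.21e9 parameters.
APPROX_PARAMS = 3.21e9
BYTES_PER_GIB = 1024 ** 3

def weight_memory_gib(num_params: float, bits_per_param: int) -> float:
    """Memory for the raw weights in GiB, ignoring every kind of overhead."""
    return num_params * bits_per_param / 8 / BYTES_PER_GIB

for bits in (16, 8, 4):
    print(f"{bits:>2}-bit weights: ~{weight_memory_gib(APPROX_PARAMS, bits):.2f} GiB")
```

Going from 16-bit to 4-bit weights cuts this estimate by a factor of four, which is what leaves room for activations and LoRA training on a consumer GPU.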


In [3]:
from unsloth import FastLanguageModel
import torch
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Llama-3.2-3B-Instruct", # or choose "unsloth/Llama-3.2-1B-Instruct"
    # model_name = "unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
==((====))==  Unsloth 2024.10.6: Fast Llama patching. Transformers = 4.44.2.
   \\   /|    GPU: NVIDIA GeForce RTX 3080 Ti. Max memory: 11.753 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.4.1. CUDA = 8.6. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.28.post1. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


Unsloth: We fixed a gradient accumulation bug, but it seems like you don't have the latest transformers version!
Please update transformers, TRL and unsloth via:
`pip install --upgrade --no-cache-dir unsloth git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/trl.git`


In [4]:
model = FastLanguageModel.get_peft_model(
    model,
    r = 128, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = True,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

Unsloth 2024.10.6 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.
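The trainable-parameter count that Unsloth reports when training starts (194,510,848) can be reproduced by hand. Each LoRA-wrapped linear layer with shape `in_features -> out_features` adds `r * (in_features + out_features)` parameters (the A and B matrices). The sketch assumes the published Llama-3.2-3B dimensions (hidden size 3072, MLP size 8192, 8 key/value heads of dimension 128, 28 layers); these come from the model's config, not from this notebook. It also shows the adapter scaling: standard LoRA multiplies the low-rank update by `lora_alpha / r`, while rank-stabilized LoRA (`use_rslora = True`) uses `lora_alpha / sqrt(r)`.

```python
import math

# Assumed Llama-3.2-3B dimensions (from its public config, not from this notebook).
HIDDEN = 3072          # hidden_size
INTERMEDIATE = 8192    # intermediate_size of the MLP
KV_DIM = 8 * 128       # num_key_value_heads * head_dim
NUM_LAYERS = 28

# (in_features, out_features) for each module listed in target_modules.
TARGET_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

def lora_param_count(r: int) -> int:
    """LoRA adds A (r x in) and B (out x r) to every targeted module in every layer."""
    per_layer = sum(r * (fan_in + fan_out) for fan_in, fan_out in TARGET_SHAPES.values())
    return per_layer * NUM_LAYERS

def lora_scaling(alpha: float, r: int, use_rslora: bool) -> float:
    """Multiplier applied to the low-rank update B @ A."""
    return alpha / math.sqrt(r) if use_rslora else alpha / r

print(f"trainable LoRA params: {lora_param_count(128):,}")  # -> trainable LoRA params: 194,510,848
print(f"scaling with rsLoRA: {lora_scaling(16, 128, True):.4f}")
print(f"scaling without rsLoRA: {lora_scaling(16, 128, False):.4f}")
```

With `r = 128` and `lora_alpha = 16`, standard LoRA would scale the update by 0.125, whereas rsLoRA gives about 1.41; rsLoRA was proposed so that larger ranks do not shrink the update this aggressively.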


In [5]:
from unsloth.chat_templates import get_chat_template

tokenizer = get_chat_template(
    tokenizer,
    chat_template = "llama-3.1",
)

def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
    return { "text" : texts, }

from datasets import load_dataset


## Get dataset

In [6]:
dataset_finetune = load_dataset("CPSC532/arxiv_qa_data",
                                name="test_dataset_2024OCT23",
                                split="train",
                                token=os.getenv('HUGGINGFACE_API_KEY'),
                                )


In [7]:
dataset_finetune

Dataset({
    features: ['question', 'answer', 'source'],
    num_rows: 85
})

In [8]:

dataset_finetune['question'][0]

"What is the purpose of using prompt tuning in the framework described in the paper 'Visual prompt tuning'?"

In [9]:

dataset_finetune['answer'][0]

'According to the paper "Visual Prompt Tuning for Generative Transfer Learning" [1], the primary purpose of using prompt tuning in their framework is to adapt a pre-trained generative vision transformer model to a new target distribution or domain with minimal additional training data.\n\nPrompt tuning involves prepending learnable tokens called prompts to the input sequence of visual tokens, which guides the pre-trained transformer model to generate images that conform to the target distribution. The prompt parameters are learned via gradient descent while keeping the pre-trained transformer parameters frozen.\n\nThe authors argue that prompt tuning is a more efficient and effective way to adapt generative vision transformers to new domains compared to other transfer learning methods such as full fine-tuning or adapter tuning. They also claim that prompt tuning allows for better control over image generation, enabling the model to produce diverse and high-quality images.\n\nIn particu

Convert the dataset to the chat messages format

In [10]:
def convert_to_messages_format(example):
    return [
        {"role": "user", "content": example['question']},
        {"role": "assistant", "content": example['answer']},
    ]

In [11]:
dataset_finetune = dataset_finetune.map(
    lambda x: {"conversations": convert_to_messages_format(x)}
)
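When the function passed to `Dataset.map` returns a dict, its keys are added as new columns (or overwrite existing ones) and the other columns are kept. The plain-Python sketch below mimics that on a single made-up row, so the shape of the new `conversations` column is visible without loading the dataset; the row contents are placeholders, not real data.

```python
def convert_to_messages_format(example):
    """Same conversion as above: one question/answer pair becomes a two-turn chat."""
    return [
        {"role": "user", "content": example["question"]},
        {"role": "assistant", "content": example["answer"]},
    ]

# Hypothetical row with the same columns as the real dataset.
row = {
    "question": "What is prompt tuning?",
    "answer": "A placeholder answer.",
    "source": "placeholder",
}

# Dataset.map merges the returned dict into the row, roughly like this:
mapped = {**row, "conversations": convert_to_messages_format(row)}
print(mapped["conversations"])
```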

In [12]:
dataset_finetune = dataset_finetune.map(formatting_prompts_func, batched = True,)

Map: 100%|██████████| 85/85 [00:00<00:00, 5174.47 examples/s]


In [13]:
dataset_finetune['text'][0]

'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 July 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the purpose of using prompt tuning in the framework described in the paper \'Visual prompt tuning\'?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nAccording to the paper "Visual Prompt Tuning for Generative Transfer Learning" [1], the primary purpose of using prompt tuning in their framework is to adapt a pre-trained generative vision transformer model to a new target distribution or domain with minimal additional training data.\n\nPrompt tuning involves prepending learnable tokens called prompts to the input sequence of visual tokens, which guides the pre-trained transformer model to generate images that conform to the target distribution. The prompt parameters are learned via gradient descent while keeping the pre-trained transformer parameters frozen.\n\nThe authors argue that prom
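The printed string shows the Llama 3.1 chat layout: a default system block carrying the knowledge-cutoff and date lines, then one `<|start_header_id|>role<|end_header_id|>\n\n ... <|eot_id|>` block per turn. The function below is a simplified, hand-written approximation of that layout, reconstructed from the output above; it is not the tokenizer's actual Jinja template and ignores user-supplied system prompts and tool calls. It mainly shows where the `instruction_part` and `response_part` markers passed to `train_on_responses_only` come from.

```python
def render_llama31_chat(messages, cutting_knowledge="December 2023", today="26 July 2024"):
    """Approximate the Llama 3.1 chat layout seen in the printed example (no generation prompt)."""
    parts = [
        "<|begin_of_text|>",
        "<|start_header_id|>system<|end_header_id|>\n\n",
        f"Cutting Knowledge Date: {cutting_knowledge}\nToday Date: {today}\n\n<|eot_id|>",
    ]
    for message in messages:
        # Each turn: role header, blank line, stripped content, end-of-turn token.
        parts.append(
            f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n"
            f"{message['content'].strip()}<|eot_id|>"
        )
    return "".join(parts)

example = [
    {"role": "user", "content": "What is prompt tuning?"},
    {"role": "assistant", "content": "It prepends learnable tokens to the input."},
]
print(render_llama31_chat(example))
```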

## Dataset generated; now finetune the model

In [14]:
from trl import SFTTrainer
from transformers import TrainingArguments, DataCollatorForSeq2Seq
from unsloth import is_bfloat16_supported

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset_finetune,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
    dataset_num_proc = 1,  # Affects memory usage
    packing = False, # Setting this to True can make training up to 5x faster for short sequences.
    args = TrainingArguments(
        per_device_train_batch_size = 2, # Affects memory usage
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        num_train_epochs = 20, # 20 full passes over the dataset; set max_steps instead for a quick test run.
        # max_steps = 60,
        learning_rate = 2e-4,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
    ),
)

Map: 100%|██████████| 85/85 [00:00<00:00, 4195.88 examples/s]
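The run length and learning-rate values in the training log can be checked by hand. With 85 examples, `per_device_train_batch_size = 2`, and `gradient_accumulation_steps = 4` on one GPU, an epoch has 43 micro-batches and therefore 10 optimizer steps, so 20 epochs give the 200 total steps Unsloth reports. The `linear` scheduler with `warmup_steps = 5` ramps the learning rate up to 2e-4 over the first 5 steps and then decays it linearly to 0 at step 200. The sketch below reproduces that arithmetic; it assumes a single GPU and re-implements the linear warmup/decay formula rather than calling the Hugging Face scheduler.

```python
import math

NUM_EXAMPLES = 85
PER_DEVICE_BATCH = 2
GRAD_ACCUM = 4
NUM_GPUS = 1            # assumption: single-GPU run, as in the log
NUM_EPOCHS = 20
BASE_LR = 2e-4
WARMUP_STEPS = 5

micro_batches_per_epoch = math.ceil(NUM_EXAMPLES / (PER_DEVICE_BATCH * NUM_GPUS))
steps_per_epoch = max(micro_batches_per_epoch // GRAD_ACCUM, 1)
total_steps = steps_per_epoch * NUM_EPOCHS
effective_batch = PER_DEVICE_BATCH * GRAD_ACCUM * NUM_GPUS

def linear_lr(step: int) -> float:
    """Learning rate after `step` optimizer steps: linear warmup, then linear decay to 0."""
    if step < WARMUP_STEPS:
        return BASE_LR * step / WARMUP_STEPS
    return BASE_LR * max(0.0, (total_steps - step) / (total_steps - WARMUP_STEPS))

print(effective_batch, steps_per_epoch, total_steps)
for step in (1, 5, 6, 100):
    print(step, linear_lr(step))
```

These values agree with the logged `learning_rate` at steps 1, 5, 6, and 100 (4e-05, 0.0002, 0.000198974..., 0.000102564...).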


We also use Unsloth's `train_on_responses_only` function so that the loss is computed only on the assistant's outputs and the user's inputs are ignored. TODO: look into this further.

In [15]:
from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
    trainer,
    instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
    response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
)

Map: 100%|██████████| 85/85 [00:00<00:00, 4981.85 examples/s]


In [16]:
tokenizer.decode(trainer.train_dataset[0]["input_ids"])

'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 July 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the purpose of using prompt tuning in the framework described in the paper \'Visual prompt tuning\'?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nAccording to the paper "Visual Prompt Tuning for Generative Transfer Learning" [1], the primary purpose of using prompt tuning in their framework is to adapt a pre-trained generative vision transformer model to a new target distribution or domain with minimal additional training data.\n\nPrompt tuning involves prepending learnable tokens called prompts to the input sequence of visual tokens, which guides the pre-trained transformer model to generate images that conform to the target distribution. The prompt parameters are learned via gradient descent while keeping the pre-trained transformer parameters frozen.\n\nThe authors argue that prom

In [17]:
space = tokenizer(" ", add_special_tokens = False).input_ids[0]
tokenizer.decode([space if x == -100 else x for x in trainer.train_dataset[5]["labels"]])

'                                                    \n\nThe authors of the paper "Visual Prompt Tuning for Generative Transfer Learning" are:\n\n1. Kihyuk Sohn\n2. Huiwen Chang\n3. Jos´e Lezama\n4. Luisa Polania\n5. Han Zhang\n6. Yuan Hao\n7. Irfan Essa\n8. Lu Jiang\n\nThey are affiliated with Google Research.<|eot_id|>'
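The labels printed above show the effect of `train_on_responses_only`: the tokens before the assistant's reply are replaced by -100 (decoded here as spaces), so essentially only the answer text and its closing `<|eot_id|>` contribute to the loss. The sketch below re-implements that idea on plain integer token IDs. It is a simplified illustration, not Unsloth's implementation, and the token IDs in the example are made up: each span that starts right after a response marker and ends at the next instruction marker (or the end of the sequence) keeps its IDs, and everything else is masked.

```python
IGNORE_INDEX = -100  # label value ignored by PyTorch's cross-entropy loss by default

def find_subsequence(seq, pattern, start=0):
    """Index of the first occurrence of `pattern` in `seq` at or after `start`, else -1."""
    n = len(pattern)
    for i in range(start, len(seq) - n + 1):
        if seq[i:i + n] == pattern:
            return i
    return -1

def mask_non_responses(input_ids, instruction_ids, response_ids):
    """Return labels where only assistant-response tokens keep their IDs."""
    if not instruction_ids or not response_ids:
        raise ValueError("marker token sequences must be non-empty")
    labels = [IGNORE_INDEX] * len(input_ids)
    pos = 0
    while True:
        marker = find_subsequence(input_ids, response_ids, pos)
        if marker == -1:
            break
        body_start = marker + len(response_ids)
        body_end = find_subsequence(input_ids, instruction_ids, body_start)
        if body_end == -1:
            body_end = len(input_ids)
        labels[body_start:body_end] = input_ids[body_start:body_end]
        pos = body_end
    return labels

# Made-up IDs: [1, 2] plays the user header, [3, 4] the assistant header, 99 the end-of-turn token.
ids = [0, 1, 2, 10, 11, 99, 3, 4, 20, 21, 99, 1, 2, 12, 99, 3, 4, 22, 99]
print(mask_non_responses(ids, [1, 2], [3, 4]))
```

With a real tokenizer, the marker sequences would be obtained by encoding the `instruction_part` and `response_part` strings with `add_special_tokens = False`.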

In [18]:
#@title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = NVIDIA GeForce RTX 3080 Ti. Max memory = 11.753 GB.
3.275 GB of memory reserved.
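The memory cell divides byte counts by 1024³, so the printed "GB" figures are strictly GiB. The small helpers below make that explicit and express the reserved memory as a share of the card, using the two numbers printed above as inputs; the 12 GiB byte count is only an illustrative value.

```python
BYTES_PER_GIB = 1024 ** 3

def bytes_to_gib(num_bytes: int) -> float:
    """Convert a byte count to GiB, rounded to 3 decimals like the notebook cell."""
    return round(num_bytes / BYTES_PER_GIB, 3)

def percent_used(used_gib: float, total_gib: float) -> float:
    """Share of total GPU memory that is in use, in percent (1 decimal)."""
    return round(100 * used_gib / total_gib, 1)

print(bytes_to_gib(12 * 1024 ** 3))   # illustrative 12 GiB byte count
print(percent_used(3.275, 11.753))    # reserved vs. total from the output above
```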


In [19]:
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 85 | Num Epochs = 20
O^O/ \_/ \    Batch size per device = 2 | Gradient Accumulation steps = 4
\        /    Total batch size = 8 | Total steps = 200
 "-____-"     Number of trainable parameters = 194,510,848


**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!
`pip install --upgrade --no-cache-dir unsloth git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/trl.git`


  0%|          | 1/200 [00:02<09:50,  2.97s/it]

{'loss': 1.4582, 'grad_norm': 1.5018233060836792, 'learning_rate': 4e-05, 'epoch': 0.09}


  1%|          | 2/200 [00:05<09:18,  2.82s/it]

{'loss': 1.7517, 'grad_norm': 1.6050418615341187, 'learning_rate': 8e-05, 'epoch': 0.19}


  2%|▏         | 3/200 [00:08<09:10,  2.79s/it]

{'loss': 1.3805, 'grad_norm': 1.1010969877243042, 'learning_rate': 0.00012, 'epoch': 0.28}


  2%|▏         | 4/200 [00:10<08:24,  2.57s/it]

{'loss': 1.3147, 'grad_norm': 0.9064140319824219, 'learning_rate': 0.00016, 'epoch': 0.37}


  2%|▎         | 5/200 [00:17<13:03,  4.02s/it]

{'loss': 1.2346, 'grad_norm': 0.8346613049507141, 'learning_rate': 0.0002, 'epoch': 0.47}


  3%|▎         | 6/200 [00:19<11:13,  3.47s/it]

{'loss': 1.1619, 'grad_norm': 0.7336437702178955, 'learning_rate': 0.00019897435897435898, 'epoch': 0.56}


  4%|▎         | 7/200 [00:22<10:21,  3.22s/it]

{'loss': 1.0989, 'grad_norm': 0.7720448970794678, 'learning_rate': 0.00019794871794871796, 'epoch': 0.65}


  4%|▍         | 8/200 [00:25<09:57,  3.11s/it]

{'loss': 1.0583, 'grad_norm': 0.7950204610824585, 'learning_rate': 0.00019692307692307696, 'epoch': 0.74}


  4%|▍         | 9/200 [00:27<09:30,  2.98s/it]

{'loss': 1.2406, 'grad_norm': 0.8256083130836487, 'learning_rate': 0.0001958974358974359, 'epoch': 0.84}


  5%|▌         | 10/200 [00:31<09:43,  3.07s/it]

{'loss': 0.9475, 'grad_norm': 0.6062178611755371, 'learning_rate': 0.00019487179487179487, 'epoch': 0.93}


  6%|▌         | 11/200 [00:33<09:08,  2.90s/it]

{'loss': 0.7704, 'grad_norm': 0.7757887840270996, 'learning_rate': 0.00019384615384615385, 'epoch': 1.02}


  6%|▌         | 12/200 [00:36<08:51,  2.83s/it]

{'loss': 0.5516, 'grad_norm': 0.567615807056427, 'learning_rate': 0.00019282051282051282, 'epoch': 1.12}


  6%|▋         | 13/200 [00:39<09:01,  2.89s/it]

{'loss': 0.6509, 'grad_norm': 0.6499739289283752, 'learning_rate': 0.00019179487179487182, 'epoch': 1.21}


  7%|▋         | 14/200 [00:42<08:39,  2.80s/it]

{'loss': 0.7302, 'grad_norm': 0.7864968180656433, 'learning_rate': 0.0001907692307692308, 'epoch': 1.3}


  8%|▊         | 15/200 [00:48<11:56,  3.87s/it]

{'loss': 1.0545, 'grad_norm': 0.9858948588371277, 'learning_rate': 0.00018974358974358974, 'epoch': 1.4}


  8%|▊         | 16/200 [00:51<10:43,  3.50s/it]

{'loss': 0.6305, 'grad_norm': 0.6885298490524292, 'learning_rate': 0.0001887179487179487, 'epoch': 1.49}


  8%|▊         | 17/200 [00:53<09:57,  3.27s/it]

{'loss': 0.6546, 'grad_norm': 0.6791409850120544, 'learning_rate': 0.0001876923076923077, 'epoch': 1.58}


  9%|▉         | 18/200 [00:56<09:08,  3.02s/it]

{'loss': 0.7503, 'grad_norm': 0.7460134625434875, 'learning_rate': 0.0001866666666666667, 'epoch': 1.67}


 10%|▉         | 19/200 [00:59<09:38,  3.20s/it]

{'loss': 0.5798, 'grad_norm': 0.6023604869842529, 'learning_rate': 0.00018564102564102566, 'epoch': 1.77}


 10%|█         | 20/200 [01:02<09:08,  3.05s/it]

{'loss': 0.5704, 'grad_norm': 0.6868792176246643, 'learning_rate': 0.00018461538461538463, 'epoch': 1.86}


 10%|█         | 21/200 [01:05<08:49,  2.96s/it]

{'loss': 0.527, 'grad_norm': 0.6539011001586914, 'learning_rate': 0.00018358974358974358, 'epoch': 1.95}


 11%|█         | 22/200 [01:07<08:25,  2.84s/it]

{'loss': 0.3293, 'grad_norm': 0.715949535369873, 'learning_rate': 0.00018256410256410258, 'epoch': 2.05}


 12%|█▏        | 23/200 [01:14<11:46,  3.99s/it]

{'loss': 0.6369, 'grad_norm': 0.6234850883483887, 'learning_rate': 0.00018153846153846155, 'epoch': 2.14}


 12%|█▏        | 24/200 [01:16<10:19,  3.52s/it]

{'loss': 0.3421, 'grad_norm': 0.6229376196861267, 'learning_rate': 0.00018051282051282052, 'epoch': 2.23}


 12%|█▎        | 25/200 [01:19<09:28,  3.25s/it]

{'loss': 0.2232, 'grad_norm': 0.5256103277206421, 'learning_rate': 0.0001794871794871795, 'epoch': 2.33}


 13%|█▎        | 26/200 [01:22<08:47,  3.03s/it]

{'loss': 0.3414, 'grad_norm': 0.7033838033676147, 'learning_rate': 0.00017846153846153847, 'epoch': 2.42}


 14%|█▎        | 27/200 [01:24<08:31,  2.96s/it]

{'loss': 0.3134, 'grad_norm': 0.7388937473297119, 'learning_rate': 0.00017743589743589744, 'epoch': 2.51}


 14%|█▍        | 28/200 [01:27<08:20,  2.91s/it]

{'loss': 0.1858, 'grad_norm': 0.5216419100761414, 'learning_rate': 0.00017641025641025642, 'epoch': 2.6}


 14%|█▍        | 29/200 [01:30<08:24,  2.95s/it]

{'loss': 0.2666, 'grad_norm': 0.6766833662986755, 'learning_rate': 0.0001753846153846154, 'epoch': 2.7}


 15%|█▌        | 30/200 [01:33<07:55,  2.80s/it]

{'loss': 0.2692, 'grad_norm': 0.6553313136100769, 'learning_rate': 0.00017435897435897436, 'epoch': 2.79}


 16%|█▌        | 31/200 [01:35<07:44,  2.75s/it]

{'loss': 0.3266, 'grad_norm': 0.7390658259391785, 'learning_rate': 0.00017333333333333334, 'epoch': 2.88}


 16%|█▌        | 32/200 [01:38<08:03,  2.88s/it]

{'loss': 0.2673, 'grad_norm': 0.5776841044425964, 'learning_rate': 0.00017230769230769234, 'epoch': 2.98}


 16%|█▋        | 33/200 [01:41<07:40,  2.76s/it]

{'loss': 0.1763, 'grad_norm': 0.5795505046844482, 'learning_rate': 0.00017128205128205128, 'epoch': 3.07}


 17%|█▋        | 34/200 [01:48<10:55,  3.95s/it]

{'loss': 0.3444, 'grad_norm': 0.4800586998462677, 'learning_rate': 0.00017025641025641026, 'epoch': 3.16}


 18%|█▊        | 35/200 [01:51<09:59,  3.63s/it]

{'loss': 0.0982, 'grad_norm': 0.3815898597240448, 'learning_rate': 0.00016923076923076923, 'epoch': 3.26}


 18%|█▊        | 36/200 [01:53<08:39,  3.17s/it]

{'loss': 0.1395, 'grad_norm': 0.6245248913764954, 'learning_rate': 0.00016820512820512823, 'epoch': 3.35}


 18%|█▊        | 37/200 [01:55<07:54,  2.91s/it]

{'loss': 0.0974, 'grad_norm': 0.6435441970825195, 'learning_rate': 0.0001671794871794872, 'epoch': 3.44}


 19%|█▉        | 38/200 [01:58<07:41,  2.85s/it]

{'loss': 0.146, 'grad_norm': 0.7584537267684937, 'learning_rate': 0.00016615384615384617, 'epoch': 3.53}


 20%|█▉        | 39/200 [02:01<07:46,  2.90s/it]

{'loss': 0.0899, 'grad_norm': 0.492807000875473, 'learning_rate': 0.00016512820512820512, 'epoch': 3.63}


 20%|██        | 40/200 [02:03<07:34,  2.84s/it]

{'loss': 0.1301, 'grad_norm': 0.6575872302055359, 'learning_rate': 0.0001641025641025641, 'epoch': 3.72}


 20%|██        | 41/200 [02:07<07:49,  2.95s/it]

{'loss': 0.1368, 'grad_norm': 0.5697668194770813, 'learning_rate': 0.0001630769230769231, 'epoch': 3.81}


 21%|██        | 42/200 [02:09<07:38,  2.90s/it]

{'loss': 0.1503, 'grad_norm': 0.5209698677062988, 'learning_rate': 0.00016205128205128207, 'epoch': 3.91}


 22%|██▏       | 43/200 [02:12<07:02,  2.69s/it]

{'loss': 0.1248, 'grad_norm': 0.6573649644851685, 'learning_rate': 0.00016102564102564104, 'epoch': 4.0}


 22%|██▏       | 44/200 [02:14<06:44,  2.59s/it]

{'loss': 0.052, 'grad_norm': 0.37081319093704224, 'learning_rate': 0.00016, 'epoch': 4.09}


 22%|██▎       | 45/200 [02:16<06:41,  2.59s/it]

{'loss': 0.0544, 'grad_norm': 0.3630959987640381, 'learning_rate': 0.00015897435897435896, 'epoch': 4.19}


 23%|██▎       | 46/200 [02:19<06:52,  2.68s/it]

{'loss': 0.0665, 'grad_norm': 0.39253997802734375, 'learning_rate': 0.00015794871794871796, 'epoch': 4.28}


 24%|██▎       | 47/200 [02:22<06:59,  2.74s/it]

{'loss': 0.0648, 'grad_norm': 0.378837525844574, 'learning_rate': 0.00015692307692307693, 'epoch': 4.37}


 24%|██▍       | 48/200 [02:25<06:52,  2.71s/it]

{'loss': 0.099, 'grad_norm': 0.6500683426856995, 'learning_rate': 0.0001558974358974359, 'epoch': 4.47}


 24%|██▍       | 49/200 [02:28<07:10,  2.85s/it]

{'loss': 0.0446, 'grad_norm': 0.2938386797904968, 'learning_rate': 0.00015487179487179488, 'epoch': 4.56}


 25%|██▌       | 50/200 [02:31<07:16,  2.91s/it]

{'loss': 0.0576, 'grad_norm': 0.32962414622306824, 'learning_rate': 0.00015384615384615385, 'epoch': 4.65}


 26%|██▌       | 51/200 [02:38<09:57,  4.01s/it]

{'loss': 0.2053, 'grad_norm': 0.5714598894119263, 'learning_rate': 0.00015282051282051282, 'epoch': 4.74}


 26%|██▌       | 52/200 [02:41<09:05,  3.68s/it]

{'loss': 0.0884, 'grad_norm': 0.5070093274116516, 'learning_rate': 0.0001517948717948718, 'epoch': 4.84}


 26%|██▋       | 53/200 [02:43<08:07,  3.32s/it]

{'loss': 0.0721, 'grad_norm': 0.544553279876709, 'learning_rate': 0.00015076923076923077, 'epoch': 4.93}


 27%|██▋       | 54/200 [02:45<07:16,  2.99s/it]

{'loss': 0.0931, 'grad_norm': 0.6617953181266785, 'learning_rate': 0.00014974358974358974, 'epoch': 5.02}


 28%|██▊       | 55/200 [02:48<06:52,  2.85s/it]

{'loss': 0.0377, 'grad_norm': 0.2911519706249237, 'learning_rate': 0.00014871794871794872, 'epoch': 5.12}


 28%|██▊       | 56/200 [02:51<06:46,  2.82s/it]

{'loss': 0.0263, 'grad_norm': 0.2136574685573578, 'learning_rate': 0.00014769230769230772, 'epoch': 5.21}


 28%|██▊       | 57/200 [02:53<06:41,  2.81s/it]

{'loss': 0.0419, 'grad_norm': 0.2722257971763611, 'learning_rate': 0.00014666666666666666, 'epoch': 5.3}


 29%|██▉       | 58/200 [02:56<06:16,  2.65s/it]

{'loss': 0.0225, 'grad_norm': 0.1897166520357132, 'learning_rate': 0.00014564102564102564, 'epoch': 5.4}


 30%|██▉       | 59/200 [02:59<06:22,  2.71s/it]

{'loss': 0.0371, 'grad_norm': 0.29243403673171997, 'learning_rate': 0.0001446153846153846, 'epoch': 5.49}


 30%|███       | 60/200 [03:05<08:55,  3.82s/it]

{'loss': 0.1602, 'grad_norm': 0.3854168951511383, 'learning_rate': 0.0001435897435897436, 'epoch': 5.58}


 30%|███       | 61/200 [03:08<08:06,  3.50s/it]

{'loss': 0.0429, 'grad_norm': 0.3508232831954956, 'learning_rate': 0.00014256410256410258, 'epoch': 5.67}


 31%|███       | 62/200 [03:11<07:37,  3.31s/it]

{'loss': 0.0324, 'grad_norm': 0.2852233350276947, 'learning_rate': 0.00014153846153846156, 'epoch': 5.77}


 32%|███▏      | 63/200 [03:14<07:43,  3.38s/it]

{'loss': 0.0321, 'grad_norm': 0.26035091280937195, 'learning_rate': 0.0001405128205128205, 'epoch': 5.86}


 32%|███▏      | 64/200 [03:17<07:21,  3.24s/it]

{'loss': 0.0439, 'grad_norm': 0.27070868015289307, 'learning_rate': 0.00013948717948717947, 'epoch': 5.95}


 32%|███▎      | 65/200 [03:19<06:40,  2.97s/it]

{'loss': 0.0273, 'grad_norm': 0.29936856031417847, 'learning_rate': 0.00013846153846153847, 'epoch': 6.05}


 33%|███▎      | 66/200 [03:23<06:54,  3.09s/it]

{'loss': 0.0284, 'grad_norm': 0.2025476098060608, 'learning_rate': 0.00013743589743589745, 'epoch': 6.14}


 34%|███▎      | 67/200 [03:30<09:23,  4.24s/it]

{'loss': 0.1086, 'grad_norm': 0.1967364102602005, 'learning_rate': 0.00013641025641025642, 'epoch': 6.23}


 34%|███▍      | 68/200 [03:32<08:03,  3.66s/it]

{'loss': 0.0165, 'grad_norm': 0.32549014687538147, 'learning_rate': 0.0001353846153846154, 'epoch': 6.33}


 34%|███▍      | 69/200 [03:35<07:19,  3.35s/it]

{'loss': 0.0231, 'grad_norm': 0.26717352867126465, 'learning_rate': 0.00013435897435897437, 'epoch': 6.42}


 35%|███▌      | 70/200 [03:37<06:40,  3.08s/it]

{'loss': 0.023, 'grad_norm': 0.26833927631378174, 'learning_rate': 0.00013333333333333334, 'epoch': 6.51}


 36%|███▌      | 71/200 [03:39<06:10,  2.87s/it]

{'loss': 0.0185, 'grad_norm': 0.25442755222320557, 'learning_rate': 0.0001323076923076923, 'epoch': 6.6}


 36%|███▌      | 72/200 [03:42<06:08,  2.88s/it]

{'loss': 0.0183, 'grad_norm': 0.23648111522197723, 'learning_rate': 0.00013128205128205129, 'epoch': 6.7}


 36%|███▋      | 73/200 [03:45<06:06,  2.88s/it]

{'loss': 0.0155, 'grad_norm': 0.33141010999679565, 'learning_rate': 0.00013025641025641026, 'epoch': 6.79}


 37%|███▋      | 74/200 [03:48<06:00,  2.86s/it]

{'loss': 0.0187, 'grad_norm': 0.24655170738697052, 'learning_rate': 0.00012923076923076923, 'epoch': 6.88}


 38%|███▊      | 75/200 [03:51<05:53,  2.83s/it]

{'loss': 0.0202, 'grad_norm': 0.3576166331768036, 'learning_rate': 0.00012820512820512823, 'epoch': 6.98}


 38%|███▊      | 76/200 [03:53<05:32,  2.68s/it]

{'loss': 0.0119, 'grad_norm': 0.13945092260837555, 'learning_rate': 0.00012717948717948718, 'epoch': 7.07}


 38%|███▊      | 77/200 [03:56<05:30,  2.69s/it]

{'loss': 0.0075, 'grad_norm': 0.14763854444026947, 'learning_rate': 0.00012615384615384615, 'epoch': 7.16}


 39%|███▉      | 78/200 [03:59<05:35,  2.75s/it]

{'loss': 0.0131, 'grad_norm': 0.22325804829597473, 'learning_rate': 0.00012512820512820512, 'epoch': 7.26}


 40%|███▉      | 79/200 [04:02<05:39,  2.81s/it]

{'loss': 0.0111, 'grad_norm': 0.1044083833694458, 'learning_rate': 0.00012410256410256412, 'epoch': 7.35}


 40%|████      | 80/200 [04:05<05:50,  2.92s/it]

{'loss': 0.011, 'grad_norm': 0.109303779900074, 'learning_rate': 0.0001230769230769231, 'epoch': 7.44}


 40%|████      | 81/200 [04:08<05:41,  2.87s/it]

{'loss': 0.0072, 'grad_norm': 0.1768026500940323, 'learning_rate': 0.00012205128205128207, 'epoch': 7.53}


 41%|████      | 82/200 [04:14<07:43,  3.93s/it]

{'loss': 0.0792, 'grad_norm': 0.30397579073905945, 'learning_rate': 0.00012102564102564103, 'epoch': 7.63}


 42%|████▏     | 83/200 [04:17<06:55,  3.55s/it]

{'loss': 0.0136, 'grad_norm': 0.20253361761569977, 'learning_rate': 0.00012, 'epoch': 7.72}


 42%|████▏     | 84/200 [04:19<06:13,  3.22s/it]

{'loss': 0.0082, 'grad_norm': 0.12955781817436218, 'learning_rate': 0.00011897435897435898, 'epoch': 7.81}


 42%|████▎     | 85/200 [04:22<05:57,  3.11s/it]

{'loss': 0.0124, 'grad_norm': 0.16612285375595093, 'learning_rate': 0.00011794871794871796, 'epoch': 7.91}


 43%|████▎     | 86/200 [04:25<05:44,  3.02s/it]

{'loss': 0.0137, 'grad_norm': 0.25113245844841003, 'learning_rate': 0.00011692307692307694, 'epoch': 8.0}


 44%|████▎     | 87/200 [04:28<05:34,  2.96s/it]

{'loss': 0.0044, 'grad_norm': 0.05713268369436264, 'learning_rate': 0.00011589743589743591, 'epoch': 8.09}


 44%|████▍     | 88/200 [04:34<07:41,  4.12s/it]

{'loss': 0.0456, 'grad_norm': 0.17577870190143585, 'learning_rate': 0.00011487179487179487, 'epoch': 8.19}


 44%|████▍     | 89/200 [04:37<06:39,  3.60s/it]

{'loss': 0.0099, 'grad_norm': 0.17167045176029205, 'learning_rate': 0.00011384615384615384, 'epoch': 8.28}


 45%|████▌     | 90/200 [04:39<06:03,  3.30s/it]

{'loss': 0.0041, 'grad_norm': 0.09494686871767044, 'learning_rate': 0.00011282051282051283, 'epoch': 8.37}


 46%|████▌     | 91/200 [04:42<05:49,  3.21s/it]

{'loss': 0.0073, 'grad_norm': 0.0970773920416832, 'learning_rate': 0.0001117948717948718, 'epoch': 8.47}


 46%|████▌     | 92/200 [04:46<05:55,  3.29s/it]

{'loss': 0.0047, 'grad_norm': 0.07189386337995529, 'learning_rate': 0.00011076923076923077, 'epoch': 8.56}


 46%|████▋     | 93/200 [04:49<05:33,  3.11s/it]

{'loss': 0.0131, 'grad_norm': 0.1393193155527115, 'learning_rate': 0.00010974358974358976, 'epoch': 8.65}


 47%|████▋     | 94/200 [04:51<05:15,  2.97s/it]

{'loss': 0.0082, 'grad_norm': 0.22997477650642395, 'learning_rate': 0.00010871794871794872, 'epoch': 8.74}


 48%|████▊     | 95/200 [04:54<04:57,  2.83s/it]

{'loss': 0.0061, 'grad_norm': 0.1513374000787735, 'learning_rate': 0.0001076923076923077, 'epoch': 8.84}


 48%|████▊     | 96/200 [04:56<04:51,  2.80s/it]

{'loss': 0.0144, 'grad_norm': 0.1311728060245514, 'learning_rate': 0.00010666666666666667, 'epoch': 8.93}


 48%|████▊     | 97/200 [04:59<04:35,  2.68s/it]

{'loss': 0.0071, 'grad_norm': 0.13462676107883453, 'learning_rate': 0.00010564102564102565, 'epoch': 9.02}


 49%|████▉     | 98/200 [05:01<04:31,  2.67s/it]

{'loss': 0.0186, 'grad_norm': 0.26313260197639465, 'learning_rate': 0.00010461538461538463, 'epoch': 9.12}


 50%|████▉     | 99/200 [05:05<04:42,  2.79s/it]

{'loss': 0.0066, 'grad_norm': 0.12373708188533783, 'learning_rate': 0.0001035897435897436, 'epoch': 9.21}


 50%|█████     | 100/200 [05:07<04:31,  2.72s/it]

{'loss': 0.0032, 'grad_norm': 0.04324878752231598, 'learning_rate': 0.00010256410256410256, 'epoch': 9.3}


 50%|█████     | 101/200 [05:10<04:30,  2.74s/it]

{'loss': 0.0065, 'grad_norm': 0.1343567967414856, 'learning_rate': 0.00010153846153846153, 'epoch': 9.4}


 51%|█████     | 102/200 [05:13<04:37,  2.83s/it]

{'loss': 0.0049, 'grad_norm': 0.0775635614991188, 'learning_rate': 0.00010051282051282052, 'epoch': 9.49}


 52%|█████▏    | 103/200 [05:19<06:18,  3.91s/it]

{'loss': 0.0249, 'grad_norm': 0.1607249230146408, 'learning_rate': 9.948717948717949e-05, 'epoch': 9.58}


 52%|█████▏    | 104/200 [05:22<05:44,  3.59s/it]

{'loss': 0.0057, 'grad_norm': 0.07577455788850784, 'learning_rate': 9.846153846153848e-05, 'epoch': 9.67}


 52%|█████▎    | 105/200 [05:25<05:11,  3.28s/it]

{'loss': 0.0107, 'grad_norm': 0.21845968067646027, 'learning_rate': 9.743589743589744e-05, 'epoch': 9.77}


 53%|█████▎    | 106/200 [05:28<04:58,  3.17s/it]

{'loss': 0.0099, 'grad_norm': 0.1782718449831009, 'learning_rate': 9.641025641025641e-05, 'epoch': 9.86}


 54%|█████▎    | 107/200 [05:30<04:40,  3.02s/it]

{'loss': 0.0046, 'grad_norm': 0.08627781271934509, 'learning_rate': 9.53846153846154e-05, 'epoch': 9.95}


 54%|█████▍    | 108/200 [05:33<04:33,  2.97s/it]

{'loss': 0.0024, 'grad_norm': 0.03811035305261612, 'learning_rate': 9.435897435897436e-05, 'epoch': 10.05}


 55%|█████▍    | 109/200 [05:40<06:25,  4.23s/it]

{'loss': 0.0136, 'grad_norm': 0.07291316986083984, 'learning_rate': 9.333333333333334e-05, 'epoch': 10.14}


 55%|█████▌    | 110/200 [05:43<05:46,  3.85s/it]

{'loss': 0.0043, 'grad_norm': 0.15007193386554718, 'learning_rate': 9.230769230769232e-05, 'epoch': 10.23}


 56%|█████▌    | 111/200 [05:46<05:07,  3.45s/it]

{'loss': 0.0063, 'grad_norm': 0.06909658014774323, 'learning_rate': 9.128205128205129e-05, 'epoch': 10.33}


 56%|█████▌    | 112/200 [05:49<04:53,  3.34s/it]

{'loss': 0.0049, 'grad_norm': 0.14035001397132874, 'learning_rate': 9.025641025641026e-05, 'epoch': 10.42}


 56%|█████▋    | 113/200 [05:52<04:32,  3.13s/it]

{'loss': 0.0041, 'grad_norm': 0.07069804519414902, 'learning_rate': 8.923076923076924e-05, 'epoch': 10.51}


 57%|█████▋    | 114/200 [05:54<04:21,  3.04s/it]

{'loss': 0.0037, 'grad_norm': 0.08397391438484192, 'learning_rate': 8.820512820512821e-05, 'epoch': 10.6}


 57%|█████▊    | 115/200 [05:57<04:14,  3.00s/it]

{'loss': 0.0014, 'grad_norm': 0.01807965524494648, 'learning_rate': 8.717948717948718e-05, 'epoch': 10.7}


 58%|█████▊    | 116/200 [06:00<03:58,  2.84s/it]

{'loss': 0.0061, 'grad_norm': 0.10249766707420349, 'learning_rate': 8.615384615384617e-05, 'epoch': 10.79}


 58%|█████▊    | 117/200 [06:02<03:46,  2.72s/it]

{'loss': 0.0056, 'grad_norm': 0.10130567103624344, 'learning_rate': 8.512820512820513e-05, 'epoch': 10.88}


 59%|█████▉    | 118/200 [06:04<03:30,  2.56s/it]

{'loss': 0.0037, 'grad_norm': 0.07008056342601776, 'learning_rate': 8.410256410256411e-05, 'epoch': 10.98}


 60%|█████▉    | 119/200 [06:07<03:22,  2.50s/it]

{'loss': 0.0041, 'grad_norm': 0.19834066927433014, 'learning_rate': 8.307692307692309e-05, 'epoch': 11.07}


 60%|██████    | 120/200 [06:09<03:22,  2.54s/it]

{'loss': 0.0084, 'grad_norm': 0.09804251790046692, 'learning_rate': 8.205128205128205e-05, 'epoch': 11.16}


 60%|██████    | 121/200 [06:16<05:05,  3.86s/it]

{'loss': 0.0072, 'grad_norm': 0.0612361803650856, 'learning_rate': 8.102564102564103e-05, 'epoch': 11.26}


 61%|██████    | 122/200 [06:19<04:35,  3.53s/it]

{'loss': 0.0018, 'grad_norm': 0.09286966919898987, 'learning_rate': 8e-05, 'epoch': 11.35}


 62%|██████▏   | 123/200 [06:22<04:14,  3.31s/it]

{'loss': 0.0024, 'grad_norm': 0.040552977472543716, 'learning_rate': 7.897435897435898e-05, 'epoch': 11.44}


 62%|██████▏   | 124/200 [06:25<04:00,  3.17s/it]

{'loss': 0.0027, 'grad_norm': 0.044371046125888824, 'learning_rate': 7.794871794871795e-05, 'epoch': 11.53}


 62%|██████▎   | 125/200 [06:27<03:43,  2.98s/it]

{'loss': 0.0046, 'grad_norm': 0.12718509137630463, 'learning_rate': 7.692307692307693e-05, 'epoch': 11.63}


 63%|██████▎   | 126/200 [06:30<03:32,  2.87s/it]

{'loss': 0.0061, 'grad_norm': 0.1417839676141739, 'learning_rate': 7.58974358974359e-05, 'epoch': 11.72}


 64%|██████▎   | 127/200 [06:33<03:25,  2.81s/it]

{'loss': 0.0039, 'grad_norm': 0.06262122839689255, 'learning_rate': 7.487179487179487e-05, 'epoch': 11.81}


 64%|██████▍   | 128/200 [06:35<03:20,  2.79s/it]

{'loss': 0.0058, 'grad_norm': 0.13425715267658234, 'learning_rate': 7.384615384615386e-05, 'epoch': 11.91}


 64%|██████▍   | 129/200 [06:38<03:24,  2.89s/it]

{'loss': 0.0036, 'grad_norm': 0.05051178112626076, 'learning_rate': 7.282051282051282e-05, 'epoch': 12.0}


 65%|██████▌   | 130/200 [06:41<03:20,  2.87s/it]

{'loss': 0.0022, 'grad_norm': 0.032915763556957245, 'learning_rate': 7.17948717948718e-05, 'epoch': 12.09}


 66%|██████▌   | 131/200 [06:44<03:17,  2.86s/it]

{'loss': 0.0019, 'grad_norm': 0.026584874838590622, 'learning_rate': 7.076923076923078e-05, 'epoch': 12.19}


 66%|██████▌   | 132/200 [06:46<03:04,  2.72s/it]

{'loss': 0.002, 'grad_norm': 0.036501213908195496, 'learning_rate': 6.974358974358974e-05, 'epoch': 12.28}


 66%|██████▋   | 133/200 [06:49<02:55,  2.62s/it]

{'loss': 0.01, 'grad_norm': 0.0975642055273056, 'learning_rate': 6.871794871794872e-05, 'epoch': 12.37}


 67%|██████▋   | 134/200 [06:52<02:56,  2.68s/it]

{'loss': 0.0038, 'grad_norm': 0.07339199632406235, 'learning_rate': 6.76923076923077e-05, 'epoch': 12.47}


 68%|██████▊   | 135/200 [06:59<04:26,  4.10s/it]

{'loss': 0.0059, 'grad_norm': 0.0689399465918541, 'learning_rate': 6.666666666666667e-05, 'epoch': 12.56}


 68%|██████▊   | 136/200 [07:01<03:50,  3.60s/it]

{'loss': 0.003, 'grad_norm': 0.06283795833587646, 'learning_rate': 6.564102564102564e-05, 'epoch': 12.65}


 68%|██████▊   | 137/200 [07:04<03:30,  3.34s/it]

{'loss': 0.0019, 'grad_norm': 0.038202520459890366, 'learning_rate': 6.461538461538462e-05, 'epoch': 12.74}


 69%|██████▉   | 138/200 [07:07<03:16,  3.17s/it]

{'loss': 0.0024, 'grad_norm': 0.045345768332481384, 'learning_rate': 6.358974358974359e-05, 'epoch': 12.84}


 70%|██████▉   | 139/200 [07:09<02:58,  2.92s/it]

{'loss': 0.0024, 'grad_norm': 0.04016054794192314, 'learning_rate': 6.256410256410256e-05, 'epoch': 12.93}


 70%|███████   | 140/200 [07:12<02:49,  2.83s/it]

{'loss': 0.001, 'grad_norm': 0.01804950460791588, 'learning_rate': 6.153846153846155e-05, 'epoch': 13.02}


 70%|███████   | 141/200 [07:15<02:43,  2.78s/it]

{'loss': 0.002, 'grad_norm': 0.024483295157551765, 'learning_rate': 6.0512820512820515e-05, 'epoch': 13.12}


 71%|███████   | 142/200 [07:18<02:42,  2.81s/it]

{'loss': 0.0015, 'grad_norm': 0.02188805676996708, 'learning_rate': 5.948717948717949e-05, 'epoch': 13.21}


 72%|███████▏  | 143/200 [07:21<02:45,  2.90s/it]

{'loss': 0.0031, 'grad_norm': 0.07938370853662491, 'learning_rate': 5.846153846153847e-05, 'epoch': 13.3}


 72%|███████▏  | 144/200 [07:23<02:33,  2.74s/it]

{'loss': 0.005, 'grad_norm': 0.060798630118370056, 'learning_rate': 5.7435897435897434e-05, 'epoch': 13.4}


 72%|███████▎  | 145/200 [07:26<02:38,  2.88s/it]

{'loss': 0.0017, 'grad_norm': 0.02160460315644741, 'learning_rate': 5.6410256410256414e-05, 'epoch': 13.49}


 73%|███████▎  | 146/200 [07:29<02:29,  2.78s/it]

{'loss': 0.0011, 'grad_norm': 0.01405599620193243, 'learning_rate': 5.538461538461539e-05, 'epoch': 13.58}


 74%|███████▎  | 147/200 [07:36<03:31,  3.99s/it]

{'loss': 0.004, 'grad_norm': 0.03473062068223953, 'learning_rate': 5.435897435897436e-05, 'epoch': 13.67}


 74%|███████▍  | 148/200 [07:38<03:01,  3.48s/it]

{'loss': 0.0025, 'grad_norm': 0.06826993823051453, 'learning_rate': 5.333333333333333e-05, 'epoch': 13.77}


 74%|███████▍  | 149/200 [07:41<02:45,  3.25s/it]

{'loss': 0.0024, 'grad_norm': 0.02991563454270363, 'learning_rate': 5.230769230769231e-05, 'epoch': 13.86}


 75%|███████▌  | 150/200 [07:43<02:32,  3.06s/it]

{'loss': 0.0053, 'grad_norm': 0.05996502935886383, 'learning_rate': 5.128205128205128e-05, 'epoch': 13.95}


 76%|███████▌  | 151/200 [07:46<02:21,  2.89s/it]

{'loss': 0.0049, 'grad_norm': 0.08636781573295593, 'learning_rate': 5.025641025641026e-05, 'epoch': 14.05}


 76%|███████▌  | 152/200 [07:49<02:22,  2.97s/it]

{'loss': 0.0013, 'grad_norm': 0.022909022867679596, 'learning_rate': 4.923076923076924e-05, 'epoch': 14.14}


 76%|███████▋  | 153/200 [07:52<02:21,  3.01s/it]

{'loss': 0.0012, 'grad_norm': 0.01568475551903248, 'learning_rate': 4.8205128205128205e-05, 'epoch': 14.23}


 77%|███████▋  | 154/200 [07:55<02:13,  2.90s/it]

{'loss': 0.0014, 'grad_norm': 0.02180536277592182, 'learning_rate': 4.717948717948718e-05, 'epoch': 14.33}


 78%|███████▊  | 155/200 [07:57<02:06,  2.82s/it]

{'loss': 0.0018, 'grad_norm': 0.023408323526382446, 'learning_rate': 4.615384615384616e-05, 'epoch': 14.42}


 78%|███████▊  | 156/200 [08:04<02:54,  3.97s/it]

{'loss': 0.0034, 'grad_norm': 0.025052141398191452, 'learning_rate': 4.512820512820513e-05, 'epoch': 14.51}


 78%|███████▊  | 157/200 [08:07<02:35,  3.62s/it]

{'loss': 0.0017, 'grad_norm': 0.028578510507941246, 'learning_rate': 4.4102564102564104e-05, 'epoch': 14.6}


 79%|███████▉  | 158/200 [08:09<02:20,  3.35s/it]

{'loss': 0.0013, 'grad_norm': 0.015503151342272758, 'learning_rate': 4.3076923076923084e-05, 'epoch': 14.7}


 80%|███████▉  | 159/200 [08:12<02:06,  3.09s/it]

{'loss': 0.004, 'grad_norm': 0.06676548719406128, 'learning_rate': 4.205128205128206e-05, 'epoch': 14.79}


 80%|████████  | 160/200 [08:14<01:55,  2.90s/it]

{'loss': 0.0015, 'grad_norm': 0.021608004346489906, 'learning_rate': 4.1025641025641023e-05, 'epoch': 14.88}


 80%|████████  | 161/200 [08:17<01:48,  2.77s/it]

{'loss': 0.004, 'grad_norm': 0.05002662166953087, 'learning_rate': 4e-05, 'epoch': 14.98}


 81%|████████  | 162/200 [08:19<01:37,  2.58s/it]

{'loss': 0.0019, 'grad_norm': 0.03623126819729805, 'learning_rate': 3.8974358974358976e-05, 'epoch': 15.07}


 82%|████████▏ | 163/200 [08:22<01:35,  2.59s/it]

{'loss': 0.0011, 'grad_norm': 0.012760031968355179, 'learning_rate': 3.794871794871795e-05, 'epoch': 15.16}


 82%|████████▏ | 164/200 [08:24<01:37,  2.70s/it]

{'loss': 0.0021, 'grad_norm': 0.039339080452919006, 'learning_rate': 3.692307692307693e-05, 'epoch': 15.26}


 82%|████████▎ | 165/200 [08:27<01:33,  2.66s/it]

{'loss': 0.001, 'grad_norm': 0.012633893638849258, 'learning_rate': 3.58974358974359e-05, 'epoch': 15.35}


 83%|████████▎ | 166/200 [08:34<02:09,  3.80s/it]

{'loss': 0.0046, 'grad_norm': 0.036416683346033096, 'learning_rate': 3.487179487179487e-05, 'epoch': 15.44}


 84%|████████▎ | 167/200 [08:37<01:59,  3.63s/it]

{'loss': 0.0008, 'grad_norm': 0.007314203307032585, 'learning_rate': 3.384615384615385e-05, 'epoch': 15.53}


 84%|████████▍ | 168/200 [08:39<01:46,  3.34s/it]

{'loss': 0.0016, 'grad_norm': 0.02520889602601528, 'learning_rate': 3.282051282051282e-05, 'epoch': 15.63}


 84%|████████▍ | 169/200 [08:42<01:37,  3.16s/it]

{'loss': 0.0011, 'grad_norm': 0.013913018628954887, 'learning_rate': 3.1794871794871795e-05, 'epoch': 15.72}


 85%|████████▌ | 170/200 [08:45<01:31,  3.04s/it]

{'loss': 0.0023, 'grad_norm': 0.03339759260416031, 'learning_rate': 3.0769230769230774e-05, 'epoch': 15.81}


 86%|████████▌ | 171/200 [08:48<01:24,  2.91s/it]

{'loss': 0.0011, 'grad_norm': 0.020315932109951973, 'learning_rate': 2.9743589743589744e-05, 'epoch': 15.91}


 86%|████████▌ | 172/200 [08:50<01:18,  2.81s/it]

{'loss': 0.0015, 'grad_norm': 0.023565514013171196, 'learning_rate': 2.8717948717948717e-05, 'epoch': 16.0}


 86%|████████▋ | 173/200 [08:53<01:13,  2.74s/it]

{'loss': 0.0014, 'grad_norm': 0.02369954250752926, 'learning_rate': 2.7692307692307694e-05, 'epoch': 16.09}


 87%|████████▋ | 174/200 [08:55<01:10,  2.70s/it]

{'loss': 0.0006, 'grad_norm': 0.005428728647530079, 'learning_rate': 2.6666666666666667e-05, 'epoch': 16.19}


 88%|████████▊ | 175/200 [08:58<01:06,  2.65s/it]

{'loss': 0.0009, 'grad_norm': 0.012587697245180607, 'learning_rate': 2.564102564102564e-05, 'epoch': 16.28}


 88%|████████▊ | 176/200 [09:00<01:01,  2.56s/it]

{'loss': 0.002, 'grad_norm': 0.03203953057527542, 'learning_rate': 2.461538461538462e-05, 'epoch': 16.37}


 88%|████████▊ | 177/200 [09:02<00:57,  2.48s/it]

{'loss': 0.0007, 'grad_norm': 0.006771191023290157, 'learning_rate': 2.358974358974359e-05, 'epoch': 16.47}


 89%|████████▉ | 178/200 [09:08<01:15,  3.43s/it]

{'loss': 0.003, 'grad_norm': 0.021312309429049492, 'learning_rate': 2.2564102564102566e-05, 'epoch': 16.56}


 90%|████████▉ | 179/200 [09:11<01:05,  3.12s/it]

{'loss': 0.0015, 'grad_norm': 0.0235032606869936, 'learning_rate': 2.1538461538461542e-05, 'epoch': 16.65}


 90%|█████████ | 180/200 [09:14<01:01,  3.08s/it]

{'loss': 0.0009, 'grad_norm': 0.010025713592767715, 'learning_rate': 2.0512820512820512e-05, 'epoch': 16.74}


 90%|█████████ | 181/200 [09:16<00:54,  2.88s/it]

{'loss': 0.0023, 'grad_norm': 0.03636343032121658, 'learning_rate': 1.9487179487179488e-05, 'epoch': 16.84}


 91%|█████████ | 182/200 [09:19<00:50,  2.80s/it]

{'loss': 0.0016, 'grad_norm': 0.021754274144768715, 'learning_rate': 1.8461538461538465e-05, 'epoch': 16.93}


 92%|█████████▏| 183/200 [09:21<00:45,  2.68s/it]

{'loss': 0.0017, 'grad_norm': 0.030189065262675285, 'learning_rate': 1.7435897435897434e-05, 'epoch': 17.02}


 92%|█████████▏| 184/200 [09:24<00:42,  2.66s/it]

{'loss': 0.0007, 'grad_norm': 0.007285867352038622, 'learning_rate': 1.641025641025641e-05, 'epoch': 17.12}


 92%|█████████▎| 185/200 [09:26<00:39,  2.64s/it]

{'loss': 0.001, 'grad_norm': 0.020153498277068138, 'learning_rate': 1.5384615384615387e-05, 'epoch': 17.21}


 93%|█████████▎| 186/200 [09:29<00:36,  2.58s/it]

{'loss': 0.0019, 'grad_norm': 0.02207956649363041, 'learning_rate': 1.4358974358974359e-05, 'epoch': 17.3}


 94%|█████████▎| 187/200 [09:31<00:33,  2.56s/it]

{'loss': 0.0009, 'grad_norm': 0.01317972969263792, 'learning_rate': 1.3333333333333333e-05, 'epoch': 17.4}


 94%|█████████▍| 188/200 [09:34<00:30,  2.58s/it]

{'loss': 0.0016, 'grad_norm': 0.02557852305471897, 'learning_rate': 1.230769230769231e-05, 'epoch': 17.49}


 94%|█████████▍| 189/200 [09:36<00:27,  2.49s/it]

{'loss': 0.002, 'grad_norm': 0.030933335423469543, 'learning_rate': 1.1282051282051283e-05, 'epoch': 17.58}


 95%|█████████▌| 190/200 [09:38<00:23,  2.36s/it]

{'loss': 0.0005, 'grad_norm': 0.005254069343209267, 'learning_rate': 1.0256410256410256e-05, 'epoch': 17.67}


 96%|█████████▌| 191/200 [09:41<00:22,  2.53s/it]

{'loss': 0.002, 'grad_norm': 0.029395749792456627, 'learning_rate': 9.230769230769232e-06, 'epoch': 17.77}


 96%|█████████▌| 192/200 [09:47<00:28,  3.53s/it]

{'loss': 0.0021, 'grad_norm': 0.011168799363076687, 'learning_rate': 8.205128205128205e-06, 'epoch': 17.86}


 96%|█████████▋| 193/200 [09:49<00:22,  3.17s/it]

{'loss': 0.0018, 'grad_norm': 0.024058330804109573, 'learning_rate': 7.179487179487179e-06, 'epoch': 17.95}


 97%|█████████▋| 194/200 [09:51<00:17,  2.84s/it]

{'loss': 0.001, 'grad_norm': 0.018414976075291634, 'learning_rate': 6.153846153846155e-06, 'epoch': 18.05}


 98%|█████████▊| 195/200 [09:53<00:12,  2.58s/it]

{'loss': 0.0009, 'grad_norm': 0.018139643594622612, 'learning_rate': 5.128205128205128e-06, 'epoch': 18.14}


 98%|█████████▊| 196/200 [09:56<00:10,  2.59s/it]

{'loss': 0.0009, 'grad_norm': 0.020029248669743538, 'learning_rate': 4.102564102564103e-06, 'epoch': 18.23}


 98%|█████████▊| 197/200 [09:58<00:07,  2.58s/it]

{'loss': 0.0007, 'grad_norm': 0.008204185403883457, 'learning_rate': 3.0769230769230774e-06, 'epoch': 18.33}


 99%|█████████▉| 198/200 [10:01<00:04,  2.50s/it]

{'loss': 0.0018, 'grad_norm': 0.021906156092882156, 'learning_rate': 2.0512820512820513e-06, 'epoch': 18.42}


100%|█████████▉| 199/200 [10:03<00:02,  2.42s/it]

{'loss': 0.0012, 'grad_norm': 0.01932643912732601, 'learning_rate': 1.0256410256410257e-06, 'epoch': 18.51}


100%|██████████| 200/200 [10:05<00:00,  2.34s/it]

{'loss': 0.001, 'grad_norm': 0.01832760125398636, 'learning_rate': 0.0, 'epoch': 18.6}


100%|██████████| 200/200 [10:07<00:00,  3.04s/it]

{'train_runtime': 607.1811, 'train_samples_per_second': 2.8, 'train_steps_per_second': 0.329, 'train_loss': 0.13805641622602707, 'epoch': 18.6}
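The logged `learning_rate` values fit a linear warmup-then-decay schedule. The sketch below reproduces them, assuming a peak rate of `2e-4`, 5 warmup steps and 200 total steps. The peak rate and warmup length are inferred from the logged values, not read from the trainer configuration, which is not part of this section. The formula is the one used by transformers' `get_linear_schedule_with_warmup`, as we understand it. The sketch also checks two figures from the summary line.

```python
def linear_lr(step: int, peak_lr: float = 2e-4, warmup_steps: int = 5,
              total_steps: int = 200) -> float:
    """Learning rate after `step` scheduler steps (linear warmup, then linear decay).

    peak_lr and warmup_steps are assumptions inferred from the logged values.
    """
    if step < warmup_steps:
        # Linear warmup from 0 up to peak_lr
        return peak_lr * step / warmup_steps
    # Linear decay from peak_lr (end of warmup) down to 0 at total_steps
    return peak_lr * max(0, total_steps - step) / (total_steps - warmup_steps)


# Compare with the log: step 127 -> 7.487179487179487e-05, step 199 -> 1.0256410256410257e-06
for step in (127, 199, 200):
    print(step, linear_lr(step))

# Figures from the final summary line
steps_per_second = 200 / 607.1811   # reported as train_steps_per_second = 0.329
steps_per_epoch = 129 / 12.0        # the log reads epoch 12.0 right after step 129
print(round(steps_per_second, 3), steps_per_epoch)
```

At roughly 10.75 optimizer steps per epoch, the 200 steps add up to about 18.6 passes over the training set, which matches the final `epoch` value.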





In [20]:
#@title Show final memory and time stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora/max_memory*100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.")
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

607.1811 seconds used for training.
10.12 minutes used for training.
Peak reserved memory = 10.723 GB.
Peak reserved memory for training = 7.448 GB.
Peak reserved memory % of max memory = 91.236 %.
Peak reserved memory for training % of max memory = 63.371 %.
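The percentages can be checked by hand. Here `max_memory` is taken as the 11.753 GB shown in the Unsloth banner when the model was loaded. The reported 91.236 % is consistent with that value. `start_gpu_memory` is not printed in this section, so this sketch recovers it as the difference between the two peak figures:

```python
max_memory = 11.753            # GB, from the Unsloth banner at load time
used_memory = 10.723           # peak reserved memory reported above
used_memory_for_lora = 7.448   # peak reserved memory attributed to training

# Reserved memory before training started (model weights etc.), implied by the two figures
start_gpu_memory = round(used_memory - used_memory_for_lora, 3)

used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
headroom = round(max_memory - used_memory, 3)  # GB still unreserved at peak

print(start_gpu_memory, used_percentage, lora_percentage, headroom)
```

Only about 1 GB of the card stays unreserved at peak. That leaves little room to raise the batch size or the sequence length on this GPU.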


## Run Inference

In [25]:
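from unsloth.chat_templates import get_chat_template

# Llama 3.2 Instruct uses the Llama 3.1 chat format
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "llama-3.1",
)
FastLanguageModel.for_inference(model) # Enable native 2x faster inference

def get_response(user_query):
    # Wrap the question as a single user turn
    messages = [
        {"role": "user", "content": user_query},
    ]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize = True,
        add_generation_prompt = True, # Must add for generation
        return_tensors = "pt",
    ).to("cuda")

    # max_new_tokens = 64 caps the answer length, so longer answers are cut off mid-sentence.
    # temperature = 1.5 with min_p = 0.1 samples fairly diversely, so repeated calls can differ.
    outputs = model.generate(input_ids = inputs, max_new_tokens = 64, use_cache = True,
                             temperature = 1.5, min_p = 0.1)
    # Decoded prompt + completion for each sequence, special tokens included
    return tokenizer.batch_decode(outputs)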
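`temperature` and `min_p` interact. In Hugging Face's sampling pipeline, as we understand it, the logits are first divided by the temperature. Min-p then drops every token whose probability is below `min_p` times the top token's probability. The pure-Python sketch below shows that effect on a toy three-token distribution. `min_p_filter` is a hypothetical helper, not a library API.

```python
import math

def min_p_filter(logits, temperature=1.5, min_p=0.1):
    """Toy version of temperature scaling followed by min-p filtering.

    Returns renormalized sampling probabilities; filtered tokens get 0.0.
    """
    scaled = [x / temperature for x in logits]      # temperature > 1 flattens the distribution
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]        # subtract the max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]
    threshold = min_p * max(probs)                  # cutoff scales with the top token's probability
    kept = [p if p >= threshold else 0.0 for p in probs]
    kept_total = sum(kept)
    return [p / kept_total for p in kept]


print(min_p_filter([2.0, 1.0, -3.0]))
```

With these toy logits, temperature 1.5 pushes the third token up to about 2 % probability. That is still under the cutoff of 0.1 × 65 %, so min-p removes it. The high temperature widens the choice among plausible tokens, and min-p still prunes the long tail.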

In [26]:
dataset_finetune['question'][0]

"What is the purpose of using prompt tuning in the framework described in the paper 'Visual prompt tuning'?"

TODO: investigate how rephrasing the question changes the responses.
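One way to start: run several rephrasings through the model and score how much the answers overlap. The sketch below is hypothetical. The helper names, the word-level Jaccard metric and the example rephrasing are illustrative choices, not part of this notebook. The answer function is passed in as a parameter, so the sketch can be exercised with a stub.

```python
import re
from itertools import combinations

def jaccard(a: str, b: str) -> float:
    """Word-level Jaccard similarity between two answers (1.0 = same vocabulary)."""
    wa = set(re.findall(r"\w+", a.lower()))
    wb = set(re.findall(r"\w+", b.lower()))
    if not wa and not wb:
        return 1.0
    return len(wa & wb) / len(wa | wb)

def compare_rephrasings(answer_fn, questions):
    """Answer each question variant and return the answers plus pairwise similarities."""
    answers = [answer_fn(q) for q in questions]
    scores = {
        (i, j): jaccard(answers[i], answers[j])
        for i, j in combinations(range(len(questions)), 2)
    }
    return answers, scores

variants = [
    "What is the purpose of using prompt tuning in the framework described in the paper 'Visual prompt tuning'?",
    "Why does the paper 'Visual prompt tuning' use prompt tuning?",  # illustrative rephrasing
]

# With the model loaded, answer_fn could wrap get_response, e.g.
# lambda q: get_response(q)[0].split("<|start_header_id|>assistant<|end_header_id|>")[1]
_, demo_scores = compare_rephrasings(lambda q: "prompt tuning adapts a frozen backbone", variants)
print(demo_scores)
```

Generation samples at temperature 1.5, so the same question asked twice will also vary. Including a repeated question in `variants` gives a baseline for how much of the difference comes from the rephrasing itself.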

In [27]:
resp = get_response(dataset_finetune['question'][0])
print(resp[0].split("<|start_header_id|>assistant<|end_header_id|>")[1])



According to the paper "Visual Prompt Tuning for Generative Transfer Learning" [1], the primary purpose of using prompt tuning in their framework is to adapt a pre-trained generative vision transformer model to a new target distribution or domain with minimal additional training data.

Prompt tuning involves prepending learnable tokens called prompts to the
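The printed answer stops mid-sentence because `get_response` caps generation at `max_new_tokens = 64`. A small helper can pull the assistant turn out of the decoded string and flag whether the model finished its turn. `extract_answer` is hypothetical and pure Python. It assumes the model stops on one of the Llama 3 stop tokens listed below. `batch_decode` is called without `skip_special_tokens`, so that token stays in the decoded text.

```python
ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>"
# Llama 3 end-of-turn / end-of-text markers (assumed set of stop tokens)
END_MARKERS = ("<|eot_id|>", "<|eom_id|>", "<|end_of_text|>")

def extract_answer(decoded: str):
    """Return (answer_text, finished) from one decoded prompt+completion string.

    finished is False when no stop token follows the answer, which usually
    means generation hit max_new_tokens.
    """
    # Keep only the text after the last assistant header (added by add_generation_prompt)
    _, sep, completion = decoded.rpartition(ASSISTANT_HEADER)
    if not sep:
        raise ValueError("no assistant header found in decoded output")
    # Cut at the earliest stop marker, if any
    cut_points = [i for i in (completion.find(m) for m in END_MARKERS) if i != -1]
    finished = bool(cut_points)
    if finished:
        completion = completion[: min(cut_points)]
    return completion.strip(), finished
```

Usage: `answer, finished = extract_answer(resp[0])`. If `finished` is `False`, the answer was truncated, and raising `max_new_tokens` in `get_response` should let it complete.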


## Save to HF

In [28]:
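# Merge the LoRA adapters and push GGUF quantizations to the Hugging Face Hub
model.push_to_hub_gguf(
    "CPSC532/finetuned_model", # Change to your own "username/repo" on the Hub
    tokenizer,
    quantization_method = ["q4_k_m", "q8_0", "q5_k_m",],
    token = os.getenv("HUGGINGFACE_API_KEY"), # Loaded from .env by load_dotenv(); get a token at https://huggingface.co/settings/tokens
)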

Unsloth: Merging 4bit and LoRA weights to 16bit...
Unsloth: Will use up to 32.83 out of 62.67 RAM for saving.


100%|██████████| 28/28 [00:00<00:00, 90.82it/s]


Unsloth: Saving tokenizer... Done.
Unsloth: Saving model... This might take 5 minutes for Llama-7b...
Done.
==((====))==  Unsloth: Conversion from QLoRA to GGUF information
   \\   /|    [0] Installing llama.cpp will take 3 minutes.
O^O/ \_/ \    [1] Converting HF to GGUF 16bits will take 3 minutes.
\        /    [2] Converting GGUF 16bits to ['q4_k_m', 'q8_0', 'q5_k_m'] will take 10 minutes each.
 "-____-"     In total, you will have to wait at least 16 minutes.

Unsloth: [0] Installing llama.cpp. This will take 3 minutes...
Unsloth: [1] Converting model at CPSC532/finetuned_model into bf16 GGUF format.
The output location will be /home/owen/Desktop/github/532/implementation/finetuning/CPSC532/finetuned_model/unsloth.BF16.gguf
This will take 3 minutes...
INFO:hf-to-gguf:Loading model: finetuned_model
INFO:gguf.gguf_writer:gguf: This GGUF file is for Little Endian only
INFO:hf-to-gguf:Exporting model...
INFO:hf-to-gguf:rope_freqs.weight,           torch.float32 --> F32, shape = {64}

No files have been modified since last commit. Skipping to prevent empty commit.
No files have been modified since last commit. Skipping to prevent empty commit.


Saved GGUF to https://huggingface.co/CPSC532/finetuned_model
Unsloth: Uploading GGUF to Huggingface Hub...


No files have been modified since last commit. Skipping to prevent empty commit.
No files have been modified since last commit. Skipping to prevent empty commit.


Saved GGUF to https://huggingface.co/CPSC532/finetuned_model
Unsloth: Uploading GGUF to Huggingface Hub...


No files have been modified since last commit. Skipping to prevent empty commit.
No files have been modified since last commit. Skipping to prevent empty commit.


Saved GGUF to https://huggingface.co/CPSC532/finetuned_model


No files have been modified since last commit. Skipping to prevent empty commit.
No files have been modified since last commit. Skipping to prevent empty commit.


Saved Ollama Modelfile to https://huggingface.co/CPSC532/finetuned_model
