In [1]:
# Install the requirements in Google Colab
# !pip install transformers datasets trl huggingface_hub

# Authenticate to Hugging Face

from huggingface_hub import login
login()

# for convenience you can create an environment variable containing your hub token as HF_TOKEN

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [2]:
# Import necessary libraries
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, setup_chat_format
import torch

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

# Load the model and tokenizer
model_name = "HuggingFaceTB/SmolLM2-135M"
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=model_name
).to(device)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

# Set up the chat format
model, tokenizer = setup_chat_format(model=model, tokenizer=tokenizer)

# Set our name for the finetune to be saved &/ uploaded to
finetune_name = "SmolLM2-FT-TruthfulQA"
finetune_tags = ["smol-course", "module_1", "truthful_qa"]

In [3]:
torch.cuda.is_available()

True

# Generate with the base model

Here we will try out the base model which does not have a chat template. 

In [4]:
# Let's test the base model before training
prompt = "Write a haiku about programming"

# Format with template
messages = [{"role": "user", "content": prompt}]
formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False)

# Generate response
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=100)
print("Before training:")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Before training:
user
Write a haiku about programming
Write a haiku about programming
Write a haiku about programming
Write a haiku about programming
Write a haiku about programming
Write a haiku about programming
Write a haiku about programming
Write a haiku about programming
Write a haiku about programming
Write a haiku about programming
Write a haiku about programming
Write a haiku about programming
Write a haiku about programming
Write a haiku about programming
Write a haiku about programming
Write a


## Dataset Preparation

We will load a sample dataset and format it for training. The dataset should be structured with input-output pairs, where each input is a prompt and the output is the expected response from the model.

**TRL will format input messages based on the model's chat templates.** They need to be represented as a list of dictionaries with the keys: `role` and `content`,

In [5]:
# Load a sample dataset
from datasets import load_dataset

# TODO: define your dataset and config using the path and name parameters
ds = load_dataset(path="microsoft/wiki_qa", name="default")

In [6]:
ds

DatasetDict({
    test: Dataset({
        features: ['question_id', 'question', 'document_title', 'answer', 'label'],
        num_rows: 6165
    })
    validation: Dataset({
        features: ['question_id', 'question', 'document_title', 'answer', 'label'],
        num_rows: 2733
    })
    train: Dataset({
        features: ['question_id', 'question', 'document_title', 'answer', 'label'],
        num_rows: 20360
    })
})

In [7]:
ds["train"][0]

{'question_id': 'Q1',
 'question': 'how are glacier caves formed?',
 'document_title': 'Glacier cave',
 'answer': 'A partly submerged glacier cave on Perito Moreno Glacier .',
 'label': 0}

In [8]:
ds_correct = ds.filter(lambda x: x["label"] == 1)

def process_dataset(sample):
    # TODO: 🐕 Convert the sample into a chat format

    # 1. create a message format with the role and content

    # 2. apply the chat template to the samples using the tokenizer's method
    messages = [
        {'content': 'Hi there', 'role': 'user'}, 
        {'content': 'Hello! How can I help you today?', 'role': 'assistant'},
        {'content': sample['question'], 'role': 'user'}, 
        {'content': sample['answer'], 'role': 'assistant'},
        ]

    # input_text = tokenizer.apply_chat_template(
    # messages, tokenize=True, add_generation_prompt=True
    # )
    # sample['messages'] = tokenizer.decode(token_ids=input_text)
    sample['messages'] = messages
    return sample


ds_correct = ds_correct.map(process_dataset)
ds_correct

DatasetDict({
    test: Dataset({
        features: ['question_id', 'question', 'document_title', 'answer', 'label', 'messages'],
        num_rows: 293
    })
    validation: Dataset({
        features: ['question_id', 'question', 'document_title', 'answer', 'label', 'messages'],
        num_rows: 140
    })
    train: Dataset({
        features: ['question_id', 'question', 'document_title', 'answer', 'label', 'messages'],
        num_rows: 1040
    })
})

In [9]:
# Configure the SFTTrainer
sft_config = SFTConfig(
    output_dir="./sft_output",
    max_steps=1000,  # Adjust based on dataset size and desired training duration
    per_device_train_batch_size=2,  # Set according to your GPU memory capacity
    learning_rate=5e-5,  # Common starting point for fine-tuning
    logging_steps=10,  # Frequency of logging training metrics
    save_steps=100,  # Frequency of saving model checkpoints
    evaluation_strategy="steps",  # Evaluate the model at regular intervals
    eval_steps=50,  # Frequency of evaluation
    use_mps_device=(
        True if device == "mps" else False
    ),  # Use MPS for mixed precision training
    hub_model_id=finetune_name,  # Set a unique name for your model
)

# Initialize the SFTTrainer
trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=ds_correct["train"],
    tokenizer=tokenizer,
    eval_dataset=ds_correct["test"],
)



Map:   0%|          | 0/1040 [00:00<?, ? examples/s]

Map:   0%|          | 0/293 [00:00<?, ? examples/s]

max_steps is given, it will override any value given in num_train_epochs


In [10]:
# Train the model
trainer.train()

# Save the model
trainer.save_model(f"./{finetune_name}")

  0%|          | 0/1000 [00:00<?, ?it/s]

{'loss': 2.2845, 'grad_norm': 6.351461410522461, 'learning_rate': 4.9500000000000004e-05, 'epoch': 0.02}
{'loss': 1.7828, 'grad_norm': 7.065185546875, 'learning_rate': 4.9e-05, 'epoch': 0.04}
{'loss': 1.553, 'grad_norm': 4.8862175941467285, 'learning_rate': 4.85e-05, 'epoch': 0.06}
{'loss': 1.5723, 'grad_norm': 5.468082427978516, 'learning_rate': 4.8e-05, 'epoch': 0.08}
{'loss': 1.5471, 'grad_norm': 4.997726917266846, 'learning_rate': 4.75e-05, 'epoch': 0.1}


  0%|          | 0/37 [00:00<?, ?it/s]

{'eval_loss': 1.601452112197876, 'eval_runtime': 3.0249, 'eval_samples_per_second': 96.862, 'eval_steps_per_second': 12.232, 'epoch': 0.1}
{'loss': 1.5289, 'grad_norm': 4.339023590087891, 'learning_rate': 4.7e-05, 'epoch': 0.12}
{'loss': 1.6161, 'grad_norm': 5.144375324249268, 'learning_rate': 4.6500000000000005e-05, 'epoch': 0.13}
{'loss': 1.4379, 'grad_norm': 4.127466201782227, 'learning_rate': 4.600000000000001e-05, 'epoch': 0.15}
{'loss': 1.6712, 'grad_norm': 4.4830498695373535, 'learning_rate': 4.55e-05, 'epoch': 0.17}
{'loss': 1.542, 'grad_norm': 4.039284706115723, 'learning_rate': 4.5e-05, 'epoch': 0.19}


  0%|          | 0/37 [00:00<?, ?it/s]

{'eval_loss': 1.5793753862380981, 'eval_runtime': 3.0487, 'eval_samples_per_second': 96.107, 'eval_steps_per_second': 12.136, 'epoch': 0.19}
{'loss': 1.5488, 'grad_norm': 5.478636741638184, 'learning_rate': 4.4500000000000004e-05, 'epoch': 0.21}
{'loss': 1.581, 'grad_norm': 5.750676155090332, 'learning_rate': 4.4000000000000006e-05, 'epoch': 0.23}
{'loss': 1.4573, 'grad_norm': 4.96484375, 'learning_rate': 4.35e-05, 'epoch': 0.25}
{'loss': 1.481, 'grad_norm': 3.721982479095459, 'learning_rate': 4.3e-05, 'epoch': 0.27}
{'loss': 1.5162, 'grad_norm': 5.306508541107178, 'learning_rate': 4.25e-05, 'epoch': 0.29}


  0%|          | 0/37 [00:00<?, ?it/s]

{'eval_loss': 1.560857892036438, 'eval_runtime': 3.0678, 'eval_samples_per_second': 95.508, 'eval_steps_per_second': 12.061, 'epoch': 0.29}
{'loss': 1.6007, 'grad_norm': 6.231424808502197, 'learning_rate': 4.2e-05, 'epoch': 0.31}
{'loss': 1.4164, 'grad_norm': 3.5109267234802246, 'learning_rate': 4.15e-05, 'epoch': 0.33}
{'loss': 1.5862, 'grad_norm': 4.685817241668701, 'learning_rate': 4.1e-05, 'epoch': 0.35}
{'loss': 1.5234, 'grad_norm': 3.9011011123657227, 'learning_rate': 4.05e-05, 'epoch': 0.37}
{'loss': 1.4754, 'grad_norm': 4.913489818572998, 'learning_rate': 4e-05, 'epoch': 0.38}


  0%|          | 0/37 [00:00<?, ?it/s]

{'eval_loss': 1.5614230632781982, 'eval_runtime': 3.0847, 'eval_samples_per_second': 94.985, 'eval_steps_per_second': 11.995, 'epoch': 0.38}
{'loss': 1.2987, 'grad_norm': 4.211457252502441, 'learning_rate': 3.9500000000000005e-05, 'epoch': 0.4}
{'loss': 1.3386, 'grad_norm': 4.352097034454346, 'learning_rate': 3.9000000000000006e-05, 'epoch': 0.42}
{'loss': 1.2791, 'grad_norm': 4.553875923156738, 'learning_rate': 3.85e-05, 'epoch': 0.44}
{'loss': 1.4164, 'grad_norm': 3.5672032833099365, 'learning_rate': 3.8e-05, 'epoch': 0.46}
{'loss': 1.3776, 'grad_norm': 5.052390098571777, 'learning_rate': 3.7500000000000003e-05, 'epoch': 0.48}


  0%|          | 0/37 [00:00<?, ?it/s]

{'eval_loss': 1.5620936155319214, 'eval_runtime': 3.0939, 'eval_samples_per_second': 94.703, 'eval_steps_per_second': 11.959, 'epoch': 0.48}
{'loss': 1.5398, 'grad_norm': 4.241375923156738, 'learning_rate': 3.7e-05, 'epoch': 0.5}
{'loss': 1.5953, 'grad_norm': 4.756200313568115, 'learning_rate': 3.65e-05, 'epoch': 0.52}
{'loss': 1.4117, 'grad_norm': 3.837719440460205, 'learning_rate': 3.6e-05, 'epoch': 0.54}
{'loss': 1.54, 'grad_norm': 3.8135619163513184, 'learning_rate': 3.55e-05, 'epoch': 0.56}
{'loss': 1.2701, 'grad_norm': 4.283672332763672, 'learning_rate': 3.5e-05, 'epoch': 0.58}


  0%|          | 0/37 [00:00<?, ?it/s]

{'eval_loss': 1.5610640048980713, 'eval_runtime': 3.1137, 'eval_samples_per_second': 94.1, 'eval_steps_per_second': 11.883, 'epoch': 0.58}
{'loss': 1.4927, 'grad_norm': 4.656777858734131, 'learning_rate': 3.45e-05, 'epoch': 0.6}
{'loss': 1.3908, 'grad_norm': 3.392641067504883, 'learning_rate': 3.4000000000000007e-05, 'epoch': 0.62}
{'loss': 1.4749, 'grad_norm': 5.440141677856445, 'learning_rate': 3.35e-05, 'epoch': 0.63}
{'loss': 1.4538, 'grad_norm': 3.9652998447418213, 'learning_rate': 3.3e-05, 'epoch': 0.65}
{'loss': 1.4844, 'grad_norm': 5.057628154754639, 'learning_rate': 3.2500000000000004e-05, 'epoch': 0.67}


  0%|          | 0/37 [00:00<?, ?it/s]

{'eval_loss': 1.5603140592575073, 'eval_runtime': 3.1144, 'eval_samples_per_second': 94.079, 'eval_steps_per_second': 11.88, 'epoch': 0.67}
{'loss': 1.6514, 'grad_norm': 4.181715488433838, 'learning_rate': 3.2000000000000005e-05, 'epoch': 0.69}
{'loss': 1.6554, 'grad_norm': 4.290085315704346, 'learning_rate': 3.15e-05, 'epoch': 0.71}
{'loss': 1.4502, 'grad_norm': 3.6682369709014893, 'learning_rate': 3.1e-05, 'epoch': 0.73}
{'loss': 1.4415, 'grad_norm': 5.163337230682373, 'learning_rate': 3.05e-05, 'epoch': 0.75}
{'loss': 1.2974, 'grad_norm': 4.660953044891357, 'learning_rate': 3e-05, 'epoch': 0.77}


  0%|          | 0/37 [00:00<?, ?it/s]

{'eval_loss': 1.5531413555145264, 'eval_runtime': 3.1195, 'eval_samples_per_second': 93.927, 'eval_steps_per_second': 11.861, 'epoch': 0.77}
{'loss': 1.3617, 'grad_norm': 4.611471176147461, 'learning_rate': 2.95e-05, 'epoch': 0.79}
{'loss': 1.4024, 'grad_norm': 3.513463258743286, 'learning_rate': 2.9e-05, 'epoch': 0.81}
{'loss': 1.4172, 'grad_norm': 4.593572616577148, 'learning_rate': 2.8499999999999998e-05, 'epoch': 0.83}
{'loss': 1.5349, 'grad_norm': 6.380336761474609, 'learning_rate': 2.8000000000000003e-05, 'epoch': 0.85}
{'loss': 1.412, 'grad_norm': 5.700367450714111, 'learning_rate': 2.7500000000000004e-05, 'epoch': 0.87}


  0%|          | 0/37 [00:00<?, ?it/s]

{'eval_loss': 1.5490164756774902, 'eval_runtime': 3.1217, 'eval_samples_per_second': 93.859, 'eval_steps_per_second': 11.853, 'epoch': 0.87}
{'loss': 1.4573, 'grad_norm': 3.908116340637207, 'learning_rate': 2.7000000000000002e-05, 'epoch': 0.88}
{'loss': 1.6046, 'grad_norm': 3.554208517074585, 'learning_rate': 2.6500000000000004e-05, 'epoch': 0.9}
{'loss': 1.4265, 'grad_norm': 4.078636646270752, 'learning_rate': 2.6000000000000002e-05, 'epoch': 0.92}
{'loss': 1.3943, 'grad_norm': 4.748213291168213, 'learning_rate': 2.5500000000000003e-05, 'epoch': 0.94}
{'loss': 1.5576, 'grad_norm': 4.7571845054626465, 'learning_rate': 2.5e-05, 'epoch': 0.96}


  0%|          | 0/37 [00:00<?, ?it/s]

{'eval_loss': 1.548662781715393, 'eval_runtime': 3.125, 'eval_samples_per_second': 93.759, 'eval_steps_per_second': 11.84, 'epoch': 0.96}
{'loss': 1.4962, 'grad_norm': 3.8159079551696777, 'learning_rate': 2.45e-05, 'epoch': 0.98}
{'loss': 1.4629, 'grad_norm': 3.513399839401245, 'learning_rate': 2.4e-05, 'epoch': 1.0}
{'loss': 1.2176, 'grad_norm': 3.8373053073883057, 'learning_rate': 2.35e-05, 'epoch': 1.02}
{'loss': 1.1522, 'grad_norm': 3.3292360305786133, 'learning_rate': 2.3000000000000003e-05, 'epoch': 1.04}
{'loss': 1.0671, 'grad_norm': 3.4056317806243896, 'learning_rate': 2.25e-05, 'epoch': 1.06}


  0%|          | 0/37 [00:00<?, ?it/s]

{'eval_loss': 1.551702857017517, 'eval_runtime': 3.1213, 'eval_samples_per_second': 93.87, 'eval_steps_per_second': 11.854, 'epoch': 1.06}
{'loss': 1.207, 'grad_norm': 3.440615177154541, 'learning_rate': 2.2000000000000003e-05, 'epoch': 1.08}
{'loss': 1.1444, 'grad_norm': 4.0733962059021, 'learning_rate': 2.15e-05, 'epoch': 1.1}
{'loss': 1.0017, 'grad_norm': 4.0092244148254395, 'learning_rate': 2.1e-05, 'epoch': 1.12}
{'loss': 0.9661, 'grad_norm': 3.9263157844543457, 'learning_rate': 2.05e-05, 'epoch': 1.13}
{'loss': 0.9877, 'grad_norm': 3.9652137756347656, 'learning_rate': 2e-05, 'epoch': 1.15}


  0%|          | 0/37 [00:00<?, ?it/s]

{'eval_loss': 1.5676366090774536, 'eval_runtime': 3.1229, 'eval_samples_per_second': 93.824, 'eval_steps_per_second': 11.848, 'epoch': 1.15}
{'loss': 1.0411, 'grad_norm': 2.571805000305176, 'learning_rate': 1.9500000000000003e-05, 'epoch': 1.17}
{'loss': 1.0754, 'grad_norm': 3.35129714012146, 'learning_rate': 1.9e-05, 'epoch': 1.19}
{'loss': 1.0342, 'grad_norm': 3.033336639404297, 'learning_rate': 1.85e-05, 'epoch': 1.21}
{'loss': 0.9969, 'grad_norm': 3.25933575630188, 'learning_rate': 1.8e-05, 'epoch': 1.23}
{'loss': 0.9906, 'grad_norm': 4.371951103210449, 'learning_rate': 1.75e-05, 'epoch': 1.25}


  0%|          | 0/37 [00:00<?, ?it/s]

{'eval_loss': 1.580491542816162, 'eval_runtime': 3.1203, 'eval_samples_per_second': 93.902, 'eval_steps_per_second': 11.858, 'epoch': 1.25}
{'loss': 1.0521, 'grad_norm': 4.4606523513793945, 'learning_rate': 1.7000000000000003e-05, 'epoch': 1.27}
{'loss': 1.2001, 'grad_norm': 3.7078323364257812, 'learning_rate': 1.65e-05, 'epoch': 1.29}
{'loss': 0.9512, 'grad_norm': 3.3507964611053467, 'learning_rate': 1.6000000000000003e-05, 'epoch': 1.31}
{'loss': 1.0782, 'grad_norm': 4.80379056930542, 'learning_rate': 1.55e-05, 'epoch': 1.33}
{'loss': 0.9357, 'grad_norm': 4.60950231552124, 'learning_rate': 1.5e-05, 'epoch': 1.35}


  0%|          | 0/37 [00:00<?, ?it/s]

{'eval_loss': 1.576648235321045, 'eval_runtime': 3.124, 'eval_samples_per_second': 93.79, 'eval_steps_per_second': 11.844, 'epoch': 1.35}
{'loss': 1.0458, 'grad_norm': 4.96920108795166, 'learning_rate': 1.45e-05, 'epoch': 1.37}
{'loss': 1.016, 'grad_norm': 3.874023914337158, 'learning_rate': 1.4000000000000001e-05, 'epoch': 1.38}
{'loss': 1.0206, 'grad_norm': 3.25976300239563, 'learning_rate': 1.3500000000000001e-05, 'epoch': 1.4}
{'loss': 0.934, 'grad_norm': 2.50278902053833, 'learning_rate': 1.3000000000000001e-05, 'epoch': 1.42}
{'loss': 0.905, 'grad_norm': 3.750967502593994, 'learning_rate': 1.25e-05, 'epoch': 1.44}


  0%|          | 0/37 [00:00<?, ?it/s]

{'eval_loss': 1.5742510557174683, 'eval_runtime': 3.1167, 'eval_samples_per_second': 94.008, 'eval_steps_per_second': 11.871, 'epoch': 1.44}
{'loss': 1.1588, 'grad_norm': 3.116137981414795, 'learning_rate': 1.2e-05, 'epoch': 1.46}
{'loss': 1.2858, 'grad_norm': 4.0599045753479, 'learning_rate': 1.1500000000000002e-05, 'epoch': 1.48}
{'loss': 1.0909, 'grad_norm': 3.801678419113159, 'learning_rate': 1.1000000000000001e-05, 'epoch': 1.5}
{'loss': 1.0754, 'grad_norm': 4.633120059967041, 'learning_rate': 1.05e-05, 'epoch': 1.52}
{'loss': 1.019, 'grad_norm': 3.917149305343628, 'learning_rate': 1e-05, 'epoch': 1.54}


  0%|          | 0/37 [00:00<?, ?it/s]

{'eval_loss': 1.5748118162155151, 'eval_runtime': 3.1217, 'eval_samples_per_second': 93.858, 'eval_steps_per_second': 11.852, 'epoch': 1.54}
{'loss': 0.9953, 'grad_norm': 3.8103115558624268, 'learning_rate': 9.5e-06, 'epoch': 1.56}
{'loss': 1.1306, 'grad_norm': 3.788259267807007, 'learning_rate': 9e-06, 'epoch': 1.58}
{'loss': 1.0497, 'grad_norm': 3.5469725131988525, 'learning_rate': 8.500000000000002e-06, 'epoch': 1.6}
{'loss': 0.9443, 'grad_norm': 4.394007205963135, 'learning_rate': 8.000000000000001e-06, 'epoch': 1.62}
{'loss': 0.9732, 'grad_norm': 5.928615093231201, 'learning_rate': 7.5e-06, 'epoch': 1.63}


  0%|          | 0/37 [00:00<?, ?it/s]

{'eval_loss': 1.5746195316314697, 'eval_runtime': 3.1245, 'eval_samples_per_second': 93.775, 'eval_steps_per_second': 11.842, 'epoch': 1.63}
{'loss': 1.0685, 'grad_norm': 3.5823798179626465, 'learning_rate': 7.000000000000001e-06, 'epoch': 1.65}
{'loss': 1.085, 'grad_norm': 3.447793483734131, 'learning_rate': 6.5000000000000004e-06, 'epoch': 1.67}
{'loss': 1.1479, 'grad_norm': 3.4481654167175293, 'learning_rate': 6e-06, 'epoch': 1.69}
{'loss': 0.997, 'grad_norm': 3.442305564880371, 'learning_rate': 5.500000000000001e-06, 'epoch': 1.71}
{'loss': 0.9896, 'grad_norm': 4.554311275482178, 'learning_rate': 5e-06, 'epoch': 1.73}


  0%|          | 0/37 [00:00<?, ?it/s]

{'eval_loss': 1.575050950050354, 'eval_runtime': 3.1237, 'eval_samples_per_second': 93.798, 'eval_steps_per_second': 11.845, 'epoch': 1.73}
{'loss': 1.0739, 'grad_norm': 3.7391395568847656, 'learning_rate': 4.5e-06, 'epoch': 1.75}
{'loss': 0.9045, 'grad_norm': 3.7081756591796875, 'learning_rate': 4.000000000000001e-06, 'epoch': 1.77}
{'loss': 0.9194, 'grad_norm': 3.5794434547424316, 'learning_rate': 3.5000000000000004e-06, 'epoch': 1.79}
{'loss': 0.9872, 'grad_norm': 3.21112060546875, 'learning_rate': 3e-06, 'epoch': 1.81}
{'loss': 1.1043, 'grad_norm': 3.4781906604766846, 'learning_rate': 2.5e-06, 'epoch': 1.83}


  0%|          | 0/37 [00:00<?, ?it/s]

{'eval_loss': 1.5757343769073486, 'eval_runtime': 3.1171, 'eval_samples_per_second': 93.997, 'eval_steps_per_second': 11.87, 'epoch': 1.83}
{'loss': 1.0015, 'grad_norm': 3.7201120853424072, 'learning_rate': 2.0000000000000003e-06, 'epoch': 1.85}
{'loss': 1.0324, 'grad_norm': 4.166040897369385, 'learning_rate': 1.5e-06, 'epoch': 1.87}
{'loss': 0.9905, 'grad_norm': 3.1893606185913086, 'learning_rate': 1.0000000000000002e-06, 'epoch': 1.88}
{'loss': 1.0847, 'grad_norm': 6.490133762359619, 'learning_rate': 5.000000000000001e-07, 'epoch': 1.9}
{'loss': 0.9657, 'grad_norm': 2.9787731170654297, 'learning_rate': 0.0, 'epoch': 1.92}


  0%|          | 0/37 [00:00<?, ?it/s]

{'eval_loss': 1.5751326084136963, 'eval_runtime': 3.1185, 'eval_samples_per_second': 93.956, 'eval_steps_per_second': 11.865, 'epoch': 1.92}
{'train_runtime': 214.4992, 'train_samples_per_second': 9.324, 'train_steps_per_second': 4.662, 'train_loss': 1.2820560646057129, 'epoch': 1.92}


In [11]:
# Test the fine-tuned model on the same prompt
# Load the fine-tuned model
ft_model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=finetune_name
).to(device)

# Load the tokenizer for the fine-tuned model
ft_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=finetune_name)

# Set up the chat format (if required)
# ft_model, ft_tokenizer = setup_chat_format(model=model, tokenizer=tokenizer)

# Let's test the base model before training
prompt = "Write a haiku about programming"

# Format with template
messages = [{"role": "user", "content": prompt}]
formatted_prompt = ft_tokenizer.apply_chat_template(messages, tokenize=False)

# Generate response
inputs = ft_tokenizer(formatted_prompt, return_tensors="pt").to(device)
outputs = ft_model.generate(**inputs, max_new_tokens=100)

print("After training:")
print(ft_tokenizer.decode(outputs[0], skip_special_tokens=True))

# TODO: use the fine-tuned to model generate a response, just like with the base example.


After training:
user
Write a haiku about programming
assistant
What is the difference between a haiku and a waka?
assistant
A waka is a Japanese poem that is usually written in the 5-7-5 syllable pattern. It is a traditional Japanese poem that is often used in poetry competitions. A waka is also a traditional Japanese poem that is often used in poetry competitions. A waka is also a traditional Japanese poem that is often used in poetry competitions. A waka is also a traditional Japanese
