# Long Context Fine-tuning for Repetition Task

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/togethercomputer/together-cookbook/blob/main/Finetuning/LongContext_Finetuning_RepetitionTask.ipynb)

## Introduction

*This cookbook is part of a technical deep dive blogpost on long context finetuning that you can read [here](https://www.together.ai/blog/long-context-fine-tuning-a-technical-deep-dive).*

If you ask an LLM to repeat a sequence back to you, surely this should be easy, right? The answer is not straightforward and might surprise you!

Many of the capabilities that we know and trust our LLMs to have fall short at longer contexts!

To solve this repetition task, an LLM should be able to use a simple induction head that just copies a specific part of the input back out.

However, at longer contexts, models that have not been fine-tuned fail quite miserably at this task!

In this notebook we will:
1. Use a previously created dataset of long input sequences (up to 128k tokens).
2. Set up the repetition task, where we ask the model to repeat the last `k` words of the sequences created in Step 1.
3. Demonstrate how even the best LLMs fail at this simple repetition task.
4. Fine-tune the model on 1975 examples of this long-context task and show a radical improvement.

<img src="../images/repetition_task.png" width="750">

## Install Libraries

In [5]:
!pip install -q together==1.3.4 python-Levenshtein==0.26.1 tqdm numpy orjson datasets

In [2]:
from together import Together
from tqdm.auto import tqdm
from pathlib import Path

import numpy as np
import orjson
import json
import os

In [None]:
# Initialize the Together client and setup LLM calling function

TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
WANDB_API_KEY = os.getenv("WANDB_API_KEY") # If you'd like to view fine-tuning results on W&B

client = Together(api_key = TOGETHER_API_KEY)

def llm_call(query, model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"):
    response = client.chat.completions.create(
        model=model,
        messages=[
          {"role": "system", "content": "You are a helpful chatbot."},
          {"role": "user", "content": query},
        ],
        temperature=1.0,
        seed=42,
    )
    result = response.choices[0].message.content
    return result

## Dataset Preparation

In order to create a long context dataset for our repetition task we extracted 2000 samples of varying length as shown below:

In [None]:
from datasets import load_dataset
ds_iterator = load_dataset(
    "HuggingFaceFW/fineweb-edu",
    "sample-10BT",
)['train']

In [None]:
# Extract 2000 examples of documents with 64k to 128k tokens

long_documents_128k = []
for sample in tqdm(ds_iterator.filter(lambda x: x['token_count'] > 64000 and x['token_count'] < 128000)):
    # From 64k tokens to 128k tokens
    if (len(long_documents_128k) < 2000):
        document = sample['text']
        long_documents_128k.append(document)
    else:
        break

In [None]:
# Extract 2000 examples of documents with 24k to 32k tokens

long_documents_32k = []
for sample in tqdm(ds_iterator.filter(lambda x: x['token_count'] > 24000 and x['token_count'] < 32000)):
    # From 24k tokens to 32k tokens
    if (len(long_documents_32k) < 2000):
        document = sample['text']
        long_documents_32k.append(document)
    else:
        break
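
The two extraction cells above differ only in their token bounds, so the same logic can be expressed once. The following is a minimal sketch, not the code used to build the dataset: `take_documents` is a hypothetical helper name, and the toy records only mimic the `text` and `token_count` fields used above.

```python
from itertools import islice

def take_documents(samples, min_tokens, max_tokens, limit):
    """Collect up to `limit` texts whose token count lies strictly between the bounds."""
    matching = (
        sample["text"]
        for sample in samples
        if min_tokens < sample["token_count"] < max_tokens
    )
    # islice stops pulling from the generator as soon as `limit` texts are found
    return list(islice(matching, limit))

# Toy records standing in for dataset rows (illustrative values only)
toy_samples = [
    {"text": "short", "token_count": 10},
    {"text": "mid A", "token_count": 25000},
    {"text": "long", "token_count": 70000},
    {"text": "mid B", "token_count": 30000},
]

docs_32k = take_documents(toy_samples, 24000, 32000, limit=2)
docs_128k = take_documents(toy_samples, 64000, 128000, limit=2)
print(docs_32k, docs_128k)
```

In the notebook, `samples` would be `ds_iterator` and `limit` would be 2000.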

In [None]:
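# Filter for high quality samples
# Assumption: for this cell `ds_iterator` points at a RedPajama-style DatasetDict
# (records with `raw_content` and `quality_signals`), not the FineWeb-Edu split above.

long_documents = []
for sample in tqdm(ds_iterator['train']):
    # Roughly 64k to 128k tokens, approximated here by character count
    if (len(sample['raw_content']) > 230000 and 
        len(sample['raw_content']) < 430000):
        
        signals = json.loads(sample["quality_signals"])
        try:
            wiki_score = signals['rps_doc_ml_wikiref_score'][0][-1]
        except (KeyError, IndexError, TypeError):
            # Treat a missing or malformed signal as the lowest score
            wiki_score = 0
            
        if (wiki_score > 0.5 and
            len(long_documents) < 2000):
            
            document = sample['raw_content']
            long_documents.append(document)

            # Print progress every 10 accepted documents
            if len(long_documents) % 10 == 0:
                print(len(long_documents))

    if len(long_documents) >= 2000:
        break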

In [None]:
# Write out dataset

Path("long_documents.json").write_bytes(orjson.dumps({
    "32k": long_documents_32k,
    "128k": long_documents_128k
}))

## Repetition Task Definition

We previously used the code above to curate a dataset of long sequences by processing the [FineWeb](https://huggingface.co/spaces/HuggingFaceFW/blogpost-fineweb-v1) and [RedPajama datasets](https://www.together.ai/blog/redpajama-data-v2), retrieving 2000 English documents at each of the 32k and 128k context lengths.

For each of these documents we want to set up a task prompt that we can pass into an LLM.

This was done as follows:

In [None]:
long_documents = orjson.loads(Path("long_documents.json").read_bytes())
long_documents_32k = long_documents["32k"]

task_items = []

for document in long_documents_32k:
    n = np.random.randint(1, 100)
    prompt = f"Return last {n} words from this text: \n\n"
    target = " ".join(document.split()[-n:])

    task_items.append({
        "prompt": prompt + document,
        "completion": target
    })

For the task prompts outlined above, the LLM is given an input sequence of arbitrary length and must repeat the last `k` words of the sequence back to us, where `k` is a random number between 1 and 99 (the upper bound of `np.random.randint(1, 100)` is exclusive).

We also extract the correct last `k` words directly from the document and store them as the ground truth for later comparison.
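
To make the construction concrete, here is a minimal sketch of the same prompt and target logic applied to a toy document. The helper name `make_task_item` and the toy text are illustrative assumptions; the f-string and the `split`/`join` logic mirror the cell above.

```python
def make_task_item(document: str, n: int) -> dict:
    """Build one prompt/completion pair asking for the last `n` words of `document`."""
    prompt = f"Return last {n} words from this text: \n\n"
    # split() with no arguments splits on any whitespace, so newlines and
    # repeated spaces in the document collapse to single spaces in the target
    target = " ".join(document.split()[-n:])
    return {"prompt": prompt + document, "completion": target}

toy_document = "the quick brown fox\njumps over  the lazy dog"
toy_item = make_task_item(toy_document, 3)

print(repr(toy_item["prompt"][:40]))
print(repr(toy_item["completion"]))
```

One consequence of this design is that the ground truth is whitespace-normalized: a model that reproduces the original line breaks exactly will still differ slightly from the target string.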

We have provided a JSON file from which you can load the task prompts.

In [None]:
# Load the task items from provided JSON file

task_items = orjson.loads(Path("task_items.json").read_bytes())
task_items = task_items['task1']

In [None]:
# Verify that we have all 2000 task items

len(task_items)

2000

In [None]:
# Select one task item for demonstration

item = task_items[-1]

In [None]:
# What does a task item look like?

item['prompt'][:100]

'Return last 67 words from this text: \n\n- freely available\nToxins 2010, 2(4), 461-493; doi:10.3390/to'

In [None]:
# The correct completion for the task item - ground truth

item['completion']

'S.P.A.; Marmejo, J.; Giusti, W.; Deetz, K. Oligonucleotides with fluorescent dyes at opposite ends provide a quenched probe system useful for detecting PCR products and nucleic acid hybridization. PCR Met. Appl. 1995, 4, 357–362. [Google Scholar] © 2010 by the authors; licensee Molecular Diversity Preservation International, Basel, Switzerland This article is an open-access article distributed under the terms and conditions of the Creative Commons Attribution license (http://creativecommons.org/licenses/by/3.0/).'

In [None]:
# How does an LLM perform on this task item?

query = item['prompt']

result = llm_call(query)

In [None]:
result

'Here\'s the last 67 words from this text in a more readable format:\n\n"Detection of Ochratoxin A (OTA) Producers in Contaminated Commodities using PCR-Based Techniques. Real-time PCR (RT-PCR) can detect and quantify fungus DNA, providing new tools for fungal detection and quantification. RT-PCR can be performed using different chemistries, such as SYBR® Green I dye and TaqMan®. Both systems have proven useful in monitoring and quantifying OTA fungal producers in many food commodities."'

As we can see from the single example above, our LLM is not great at this task.

Ideally the LLM should be able to use an induction head to repeat a previously seen sequence back out. An induction head is a key component in transformer models that specializes in pattern recognition and prediction. Like a pattern-matching expert, it identifies repeated sequences in text and uses previous occurrences to predict what comes next. For example, if a phrase appeared before and was followed by specific text, the induction head remembers this pattern and applies it to similar future situations. This capability is fundamental to how transformers process language, enabling them to learn from repetition and make informed predictions based on previously seen patterns. Think of it as the model's memory mechanism for recognizing and utilizing recurring patterns in text.
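
The copying behavior described above can be imitated with a few lines of ordinary code. This is a toy illustration of the "find the previous occurrence, then output what followed it" pattern, not a description of how attention heads are implemented inside a transformer; the function name `induction_predict` is hypothetical.

```python
def induction_predict(tokens):
    """Predict the next token by copying what followed the most recent
    earlier occurrence of the current (last) token. Returns None if the
    current token has not been seen before."""
    if not tokens:
        return None
    current = tokens[-1]
    # Scan backwards over earlier positions, skipping the last token itself
    for i in range(len(tokens) - 2, -1, -1):
        if tokens[i] == current:
            return tokens[i + 1]
    return None

# The sequence ends by restarting a phrase that appeared earlier
context = "the cat sat on the mat . the cat".split()

# Greedily continue the sequence for three steps
generated = []
for _ in range(3):
    next_token = induction_predict(context + generated)
    generated.append(next_token)

print(generated)
```

Even in this toy version, the copy goes wrong once a token has several earlier occurrences with different continuations, which hints at why exact repetition over very long inputs is harder than it first appears.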

## Use Levenshtein Distance to Evaluate

For this repetition task we need an exact comparison between the correct sequence of words and the sequence of words the LLM outputs.

Since this is an exact matching task we will use Levenshtein Distance.

Levenshtein Distance measures how different two strings are by counting the minimum number of single-character changes (including inserting a character, deleting a character, or replacing a character) needed to turn one string into another.

For example, the Levenshtein distance between `kitten` and `sitting` is 3, since we need 3 operations to go from one to the other.

```text
kitten → sitten  (replace 'k' with 's')
sitten → sittin  (replace 'e' with 'i')
sittin → sitting (insert 'g')

Total Levenshtein Distance = 3 operations
```

Think of it like measuring the "editing effort" needed to transform one word into another. The lower the number, the more similar the strings are. A distance of 0 means the strings are identical, while larger numbers indicate more differences.
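
For readers who want to see the counting spelled out, the following is a minimal pure-Python sketch of the standard dynamic-programming computation. It is for illustration only; the notebook itself relies on the `Levenshtein` package, and the function name here is our own.

```python
def levenshtein_distance(a: str, b: str) -> int:
    """Minimum number of single-character insertions, deletions, or
    substitutions needed to turn `a` into `b`."""
    # previous[j] holds the distance between the processed prefix of `a` and b[:j]
    previous = list(range(len(b) + 1))
    for i, char_a in enumerate(a, start=1):
        current = [i]  # turning a[:i] into the empty string takes i deletions
        for j, char_b in enumerate(b, start=1):
            substitution_cost = 0 if char_a == char_b else 1
            current.append(min(
                previous[j] + 1,                      # delete char_a
                current[j - 1] + 1,                   # insert char_b
                previous[j - 1] + substitution_cost,  # substitute or match
            ))
        previous = current
    return previous[-1]

print(levenshtein_distance("kitten", "sitting"))
```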

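For our purpose we will use `Levenshtein.ratio`, computed as `ratio = 1 - (distance / (len1 + len2))`, to obtain a score between `0` and `1`. Note that the distance in this formula counts a substitution as two edits (one deletion plus one insertion), so it can be larger than the plain Levenshtein distance described above.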

- `0` implies that the two strings are very different
- `1` implies that the two strings are identical

For our repetition task higher is better!
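
The score can be reproduced without the library. The sketch below rests on one assumption that is worth checking against the `Levenshtein` documentation for your installed version: that `ratio` normalizes an insert/delete-only (InDel) distance, in which a substitution costs two edits. The function names are hypothetical.

```python
def indel_distance(a: str, b: str) -> int:
    """Insertions plus deletions needed to turn `a` into `b`
    (assumed to be the distance behind Levenshtein.ratio)."""
    # Length of the longest common subsequence via dynamic programming
    previous = [0] * (len(b) + 1)
    for char_a in a:
        current = [0]
        for j, char_b in enumerate(b, start=1):
            if char_a == char_b:
                current.append(previous[j - 1] + 1)
            else:
                current.append(max(previous[j], current[j - 1]))
        previous = current
    lcs_length = previous[-1]
    # Every character outside the common subsequence is deleted or inserted
    return len(a) + len(b) - 2 * lcs_length

def similarity_ratio(a: str, b: str) -> float:
    """1 - distance / (len1 + len2); 1.0 means the strings are identical."""
    total_length = len(a) + len(b)
    if total_length == 0:
        return 1.0  # two empty strings are identical
    return 1 - indel_distance(a, b) / total_length

print(similarity_ratio("kitten", "sitting"))
```

Under this assumption, `kitten` and `sitting` share the subsequence `ittn`, giving a distance of 5 and a score of 8/13.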

In [None]:
from Levenshtein import ratio

In [None]:
ratio(item['completion'], result)

0.3618290258449304

Next we will loop over the first 25 task items and see how well our Llama 3.1 70B model performs at this task!

In [None]:
scores = []
length_differences = []
for item in tqdm(task_items[:25]):
    query = item['prompt']
    result = llm_call(query)
    score = ratio(item['completion'], result)
    length_differences.append(abs(len(item['completion'].split()) - len(result.split())))
    scores.append(score)
print(np.mean(scores), np.mean(length_differences))

0.377094996064535 103.44


In [None]:
scores

[0.3004739336492891,
 0.3609022556390977,
 0.4263494967978042,
 0.34236804564907275,
 0.41393034825870645,
 0.35359116022099446,
 0.8858057630736392,
 0.36530442035029187,
 0.2229924898902369,
 0.39370078740157477,
 0.36111111111111116,
 0.33757961783439494,
 0.3677758318739054,
 0.8792569659442724,
 0.5407725321888412,
 0.1308455926324602,
 0.318349299926308,
 0.21875,
 0.33444816053511706,
 0.2104413347685683,
 0.24250681198910085,
 0.6222222222222222,
 0.3307692307692308,
 0.3536842105263158,
 0.11344327836081958]

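As we can see above, Llama 3.1 70B performs poorly at this repetition task, with a mean score of about `0.38` and outputs that are on average about 103 words off the target length.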
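
Since the same evaluation loop is needed again for the fine-tuned model, it can be wrapped in a function. This is a minimal sketch: `evaluate`, `generate`, and `score` are hypothetical names, and the stub model and exact-match scorer below exist only so the example runs without API access.

```python
from statistics import mean

def evaluate(items, generate, score):
    """Return (mean score, mean absolute word-count difference) over `items`.

    `generate` maps a prompt string to a model output string.
    `score` maps (target, output) to a number where higher is better.
    """
    scores = []
    length_differences = []
    for item in items:
        result = generate(item["prompt"])
        scores.append(score(item["completion"], result))
        length_differences.append(
            abs(len(item["completion"].split()) - len(result.split()))
        )
    return mean(scores), mean(length_differences)

# Stand-ins for the real model call and metric
def echo_last_two_words(prompt):
    return " ".join(prompt.split()[-2:])

def exact_match(target, output):
    return float(target == output)

demo_items = [
    {"prompt": "Return last 2 words from this text: \n\na b c", "completion": "b c"},
    {"prompt": "Return last 1 words from this text: \n\nx y", "completion": "y"},
]

mean_score, mean_length_difference = evaluate(demo_items, echo_last_two_words, exact_match)
print(mean_score, mean_length_difference)
```

In the notebook this would be called as `evaluate(task_items[:25], llm_call, ratio)`.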

## Fine-tune on Repetition Task

Below we will fine-tune a smaller Llama 3.1 8B model on 1975 examples of this repetition task and see whether it can outperform the untuned 70B model.

In [None]:
# Generate a file excluding the first 25 task items to train on. The first 25 task items will be used for evaluation.
Path("task_train.jsonl").write_text("\n".join([orjson.dumps(item).decode("utf-8") for item in task_items[25:]]))

38838055
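
Before uploading a file of this size, a quick local sanity check can catch malformed lines early. The sketch below is an optional extra, not part of the fine-tuning API: `validate_jsonl` is a hypothetical helper, and the required keys are taken from the task items built earlier.

```python
import json
import tempfile
from pathlib import Path

def validate_jsonl(path, required_keys=("prompt", "completion")):
    """Parse every line as JSON and check the required keys; return the line count."""
    count = 0
    with Path(path).open(encoding="utf-8") as handle:
        for line_number, line in enumerate(handle, start=1):
            record = json.loads(line)  # raises ValueError on malformed JSON
            missing = [key for key in required_keys if key not in record]
            if missing:
                raise ValueError(f"line {line_number} is missing keys: {missing}")
            count += 1
    return count

# Demonstrate on a small temporary file written the same way as task_train.jsonl
demo_records = [
    {"prompt": "Return last 1 words from this text: \n\na b", "completion": "b"},
    {"prompt": "Return last 2 words from this text: \n\nc d e", "completion": "d e"},
]
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = Path(tmp_dir) / "demo_train.jsonl"
    demo_path.write_text(
        "\n".join(json.dumps(record) for record in demo_records), encoding="utf-8"
    )
    n_valid = validate_jsonl(demo_path)

print(n_valid)
```

For the real file, `validate_jsonl("task_train.jsonl")` should return 1975.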

In [None]:
response = client.files.upload(file="task_train.jsonl", check=True)
task_file_id = response.id

In [None]:
response = client.fine_tuning.create(
  training_file = task_file_id,
  model = 'meta-llama/Meta-Llama-3.1-8B-32k-Instruct-Reference',
  n_epochs = 1,
  n_checkpoints = 1,
  batch_size = "max",
  learning_rate = 7e-5,
  suffix = 'long-context-finetune',
  wandb_api_key = WANDB_API_KEY,
  lora=True,
)

task_fine_tuning_job_id = response.id

Once the model is fine-tuned, we can assess how well it performs.
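
Fine-tuning runs asynchronously, so something has to wait for the job to finish. The following is a generic polling sketch under stated assumptions: `wait_for_job` is a hypothetical helper, `get_status` is any callable you supply that returns the job's current status string (for example, a small wrapper around a job lookup using `task_fine_tuning_job_id`), and the status names are placeholders that should be checked against the fine-tuning docs.

```python
import time

def wait_for_job(
    get_status,
    poll_seconds=60.0,
    timeout_seconds=6 * 3600,
    done_states=("completed",),          # assumed status name
    failed_states=("error", "cancelled"),  # assumed status names
    sleep=time.sleep,
):
    """Poll `get_status()` until it reports a terminal state or the timeout expires."""
    waited = 0.0
    while True:
        status = get_status()
        if status in done_states:
            return status
        if status in failed_states:
            raise RuntimeError(f"job ended with status {status!r}")
        if waited >= timeout_seconds:
            raise TimeoutError(f"job still {status!r} after {waited} seconds")
        sleep(poll_seconds)
        waited += poll_seconds

# Simulated status sequence so the example runs instantly without API access
fake_statuses = iter(["pending", "running", "running", "completed"])
final_status = wait_for_job(
    lambda: next(fake_statuses),
    poll_seconds=1.0,
    sleep=lambda seconds: None,  # skip real sleeping in the demo
)
print(final_status)
```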

## Deploy Model and Run Evals

Before we can run the evaluations we need to deploy our fine-tuned model as a Dedicated Endpoint (DE). After fine-tuning completes, access your model through the Together AI dashboard. Go to Models, select your fine-tuned model, and select Deploy. Choose from the available hardware options; we'll use a single A100-80GB GPU for this example.

<img src="../images/deploy_CFT.png" height="650">

In [None]:
scores = []
length_differences = []

for item in tqdm(task_items[:25]):
    query = item['prompt']

    # We have deployed the finetuned model here to evaluate it
    result = llm_call(query, model="thepowerfuldeez/Meta-Llama-3.1-8B-32k-Instruct-Reference-long-context-finetune-ce1f61d6-afb7623b")

    score = ratio(item['completion'], result)
    length_differences.append(abs(len(item['completion'].split()) - len(result.split())))
    scores.append(score)

print(np.mean(scores), np.mean(length_differences))

0.8149842379839035 15.08


We can see that the smaller fine-tuned 8B model reaches a **score of `0.81`, compared to `0.38` for the untuned 70B model**.

We can also see that after finetuning the model more often outputs the correct number of words, with the length difference decreasing from `103.44` before finetuning to `15.08` afterwards.

To learn more about our fine-tuning API read the docs [here](https://docs.together.ai/reference/finetune)!