# Router Fine-tuning

In this notebook, we experiment with fine-tuning an LLM-powered router. We try a few different approaches, with query + ground-truth "choice" as the training signal.

1. Fine-tuning embeddings
2. Fine-tuning a cross-encoder

Our dataset will be Wikipedia articles of different cities. 

We will generate a synthetic dataset for each approach to fine-tune over. We will also run some basic evaluations.

In [28]:
import nest_asyncio
nest_asyncio.apply()

In [22]:
!pip install spacy

Collecting spacy
  Downloading spacy-3.7.2-cp310-cp310-macosx_11_0_arm64.whl (6.6 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.6/6.6 MB[0m [31m29.5 MB/s[0m eta [36m0:00:00[0m31m31.1 MB/s[0m eta [36m0:00:01[0m
Collecting wasabi<1.2.0,>=0.9.1
  Downloading wasabi-1.1.2-py3-none-any.whl (27 kB)
Collecting smart-open<7.0.0,>=5.2.1
  Downloading smart_open-6.4.0-py3-none-any.whl (57 kB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.0/57.0 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting weasel<0.4.0,>=0.1.0
  Downloading weasel-0.3.3-py3-none-any.whl (49 kB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.8/49.8 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting murmurhash<1.1.0,>=0.28.0
  Downloading murmurhash-1.0.10-cp310-cp310-macosx_11_0_arm64.whl (26 kB)
Collecting preshed<3.1.0,>=3.0.2
  Downloading preshed-3.0.9-cp310-cp310-macosx_11_0_arm64.whl (

## Setup

In [1]:
wiki_titles = [
    "Toronto",
    "Seattle",
    "Chicago",
    "Boston",
    "Houston",
    "Tokyo",
    "Berlin",
    "Lisbon",
]

In [2]:
from pathlib import Path

import requests

for title in wiki_titles:
    response = requests.get(
        "https://en.wikipedia.org/w/api.php",
        params={
            "action": "query",
            "format": "json",
            "titles": title,
            "prop": "extracts",
            # 'exintro': True,
            "explaintext": True,
        },
    ).json()
    page = next(iter(response["query"]["pages"].values()))
    wiki_text = page["extract"]

    data_path = Path("data")
    if not data_path.exists():
        Path.mkdir(data_path)

    with open(data_path / f"{title}.txt", "w") as fp:
        fp.write(wiki_text)

In [4]:
from llama_index import SimpleDirectoryReader

# Load all wiki documents
city_docs = {}
for wiki_title in wiki_titles:
    city_docs[wiki_title] = SimpleDirectoryReader(
        input_files=[f"data/{wiki_title}.txt"]
    ).load_data()

In [5]:
from llama_index import ServiceContext
from llama_index.llms import OpenAI

gpt_35_context = ServiceContext.from_defaults(
    llm=OpenAI(model="gpt-3.5-turbo", temperature=0.3)
)

In [6]:
# define descriptions/choices for tools
city_descs_dict = {}
# these choices will be passed to the router selector 
choices = []
choice_to_id_dict = {}

for idx, wiki_title in enumerate(wiki_titles):
    vector_desc = (
        "Useful for questions related to specific aspects of"
        f" {wiki_title} (e.g. the history, arts and culture,"
        " sports, demographics, or more)."
    )
    summary_desc = (
        "Useful for any requests that require a holistic summary"
        f" of EVERYTHING about {wiki_title}. For questions about"
        " more specific sections, please use the vector_tool."
    )
    doc_id_vector = f"{wiki_title}_vector"
    doc_id_summary = f"{wiki_title}_summary"
    city_descs_dict[doc_id_vector] = vector_desc
    city_descs_dict[doc_id_summary] = summary_desc

    choices.extend([vector_desc, summary_desc])
    choice_to_id_dict[idx*2] = f"{wiki_title}_vector"
    choice_to_id_dict[idx*2+1] = f"{wiki_title}_summary"

In [10]:
from llama_index.llms import OpenAI
from llama_index.prompts import PromptTemplate

llm = OpenAI(model_name="gpt-3.5-turbo")

summary_q_tmpl = """\
You are a summary question generator. Given an existing question which asks for a summary of a given topic, \
generate {num_vary} related queries that also ask for a summary of the topic.

For example, assuming we're generating 3 related questions:
Base Question: Can you tell me more about Boston?
Question Variations:
Give me an overview of Boston as a city.
Can you describe different aspects of Boston, from the history to the sports scene to the food?
Write a concise summary of Boston; I've never been.

Now let's give it a shot! 

Base Question: {base_question}
Question Variations:
"""
summary_q_prompt = PromptTemplate(summary_q_tmpl)

In [59]:
from collections import defaultdict
from llama_index.evaluation import DatasetGenerator
from llama_index.finetuning import EmbeddingQAFinetuneDataset
from llama_index.node_parser import SimpleNodeParser
from tqdm.notebook import tqdm

def generate_dataset(
    wiki_titles, 
    city_descs_dict, 
    llm, 
    summary_q_prompt, 
    num_vector_qs_per_node=2,
    num_summary_qs=4
):
    # generate dataset from each wikipedia page
    queries = {}
    corpus = {}
    relevant_docs = defaultdict(list)
    for idx, wiki_title in enumerate(tqdm(wiki_titles)):
        doc_id_vector = f"{wiki_title}_vector"
        doc_id_summary = f"{wiki_title}_summary"
        corpus[doc_id_vector] = city_descs_dict[doc_id_vector]
        corpus[doc_id_summary] = city_descs_dict[doc_id_summary]
    
        # generate questions for semantic search
        node_parser = SimpleNodeParser.from_defaults()
        nodes = node_parser.get_nodes_from_documents(city_docs[wiki_title])
        
        dataset_generator = DatasetGenerator(
            nodes,
            service_context=gpt_35_context,
            num_questions_per_chunk=num_vector_qs_per_node,
        )
        doc_questions = dataset_generator.generate_questions_from_nodes(num=len(nodes) * num_vector_qs_per_node)
        for query_idx, doc_question in enumerate(doc_questions):
            query_id = f"{wiki_title}_{query_idx}"
            relevant_docs[query_id] = [doc_id_vector]
            queries[query_id] = doc_question
            
        # generate questions for summarization
        base_q = f"Give me a summary of {wiki_title}"
        fmt_prompt = summary_q_prompt.format(
            num_vary=num_summary_qs,
            base_question=base_q,
        )
        raw_response = llm.complete(fmt_prompt)
        raw_lines = str(raw_response).split("\n")
        doc_summary_questions = [l for l in raw_lines if l != ""]
        print(f"[{idx}] Original Question: {base_q}")
        print(f"[{idx}] Generated Question Variations: {doc_summary_questions}")
        for query_idx, doc_summary_question in enumerate(doc_summary_questions):
            query_id = f"{wiki_title}_{query_idx}"
            relevant_docs[query_id] = [doc_id_summary]
            queries[query_id] = doc_summary_question 

    return EmbeddingQAFinetuneDataset(queries=queries, corpus=corpus, relevant_docs=relevant_docs)

In [None]:
dataset = generate_dataset(
    wiki_titles, 
    city_descs_dict, 
    llm, 
    summary_q_prompt, 
    num_vector_qs_per_node=4,
    num_summary_qs=5
)

  0%|          | 0/8 [00:00<?, ?it/s]

[0] Original Question: Give me a summary of Toronto
[0] Generated Question Variations: ['What are the key highlights of Toronto?', "Can you provide a brief overview of Toronto's history, culture, and attractions?", 'In a few sentences, can you summarize what makes Toronto unique?', 'Tell me about the main features and characteristics of Toronto.', "Can you give me a concise summary of Toronto's economy, population, and geography?"]
[1] Original Question: Give me a summary of Seattle
[1] Generated Question Variations: ['What are the key highlights of Seattle?', "Can you provide a brief overview of Seattle's history, culture, and attractions?", 'In a few sentences, summarize what makes Seattle unique.', "Can you give me a quick rundown of Seattle's top industries and landmarks?", 'Write a concise summary of Seattle for someone who has never been there.']


In [62]:
# dataset.queries

In [63]:
# [optional] save
dataset.save_json("dataset.json")

In [64]:
# [optional] load
dataset = EmbeddingQAFinetuneDataset.from_json("dataset.json")

In [72]:
import random

def split_train_val_by_query(dataset, split=0.7):
    """Split dataset by queries."""
    query_ids = list(dataset.queries.keys())
    query_ids_shuffled = random.sample(query_ids, len(query_ids))
    split_idx = int(len(query_ids) * split)
    train_query_ids = query_ids_shuffled[:split_idx]
    eval_query_ids = query_ids_shuffled[split_idx:]

    train_queries = {qid: dataset.queries[qid] for qid in train_query_ids}
    eval_queries = {qid: dataset.queries[qid] for qid in eval_query_ids}

    train_rel_docs = {qid: dataset.relevant_docs[qid] for qid in train_query_ids}
    eval_rel_docs = {qid: dataset.relevant_docs[qid] for qid in eval_query_ids}

    train_dataset = EmbeddingQAFinetuneDataset(
        queries=train_queries,
        corpus=dataset.corpus,
        relevant_docs=train_rel_docs
    )
    eval_dataset = EmbeddingQAFinetuneDataset(
        queries=eval_queries,
        corpus=dataset.corpus,
        relevant_docs=eval_rel_docs
    )
    return train_dataset, eval_dataset

In [73]:
train_dataset, eval_dataset = split_train_val_by_query(dataset, split=0.7)

In [78]:
# eval_dataset.queries

## Fine-tuning Embeddings

In this section we try to fine-tune embeddings.

In [112]:
# generate embeddings dataset 
from llama_index.finetuning import SentenceTransformersFinetuneEngine 

In [122]:
finetune_engine = SentenceTransformersFinetuneEngine(
    train_dataset,
    model_id="BAAI/bge-small-en",
    model_output_path="test_model2",
    val_dataset=eval_dataset,
    epochs=10
)

In [123]:
finetune_engine.finetune()

Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/21 [00:00<?, ?it/s]

Iteration:   0%|          | 0/21 [00:00<?, ?it/s]

Iteration:   0%|          | 0/21 [00:00<?, ?it/s]

Iteration:   0%|          | 0/21 [00:00<?, ?it/s]

Iteration:   0%|          | 0/21 [00:00<?, ?it/s]

Iteration:   0%|          | 0/21 [00:00<?, ?it/s]

Iteration:   0%|          | 0/21 [00:00<?, ?it/s]

Iteration:   0%|          | 0/21 [00:00<?, ?it/s]

Iteration:   0%|          | 0/21 [00:00<?, ?it/s]

Iteration:   0%|          | 0/21 [00:00<?, ?it/s]

In [124]:
ft_embed_model = finetune_engine.get_finetuned_model()

In [125]:
ft_embed_model

HuggingFaceEmbedding(model_name='test_model2', embed_batch_size=10, callback_manager=<llama_index.callbacks.base.CallbackManager object at 0x2f7ababc0>, tokenizer_name='test_model2', max_length=512, pooling='cls', normalize='True', query_instruction=None, text_instruction=None, cache_folder=None)

## Run Evaluations

In this section we evaluate the quality of our fine-tuned embedding model vs. our base model in selecting the right choice.

We plug both into our `EmbeddingSelector` abstraction.

We also compare against a base `LLMSingleSelector` using GPT-4. 

In [126]:
# define baseline embedding model
from llama_index.embeddings import resolve_embed_model

base_embed_model = resolve_embed_model("local:BAAI/bge-small-en")

In [127]:
from llama_index.selectors import EmbeddingSingleSelector, LLMSingleSelector

ft_selector = EmbeddingSingleSelector.from_defaults(
    embed_model=ft_embed_model
)
base_selector = EmbeddingSingleSelector.from_defaults(
    embed_model=base_embed_model
)

In [128]:
import numpy as np

def run_evals(eval_dataset, selector, choices, choice_to_id_dict):
    # we just measure accuracy
    eval_pairs = eval_dataset.query_docid_pairs
    matches = []
    for query, relevant_doc_ids in tqdm(eval_pairs):
        result = selector.select(choices, query)
        # assume single selection for now
        pred_doc_id = choice_to_id_dict[result.inds[0]]
        gt_doc_id = relevant_doc_ids[0]
        matches.append(gt_doc_id == pred_doc_id)
    return np.array(matches)

In [None]:
base_selector.select(choices, list(eval_dataset.queries.values())[0])

In [129]:
ft_matches = run_evals(eval_dataset, ft_selector, choices, choice_to_id_dict)
# base_matches = run_evals(eval_dataset, base_selector, choices, choice_to_id_dict)

  0%|          | 0/90 [00:00<?, ?it/s]

Useful for questions related to specific aspects of Toronto (e.g. the history, arts and culture, sports, demographics, or more).
Useful for questions related to specific aspects of Toronto (e.g. the history, arts and culture, sports, demographics, or more).
Useful for any requests that require a holistic summary of EVERYTHING about Toronto. For questions about more specific sections, please use the vector_tool.
Useful for any requests that require a holistic summary of EVERYTHING about Toronto. For questions about more specific sections, please use the vector_tool.
Useful for questions related to specific aspects of Seattle (e.g. the history, arts and culture, sports, demographics, or more).
Useful for questions related to specific aspects of Seattle (e.g. the history, arts and culture, sports, demographics, or more).
Useful for any requests that require a holistic summary of EVERYTHING about Seattle. For questions about more specific sections, please use the vector_tool.
Useful for an

In [95]:
np.mean(base_matches)

0.1111111111111111

In [130]:
np.mean(ft_matches)

0.24444444444444444

In [109]:
# also try LLM
from llama_index.llms import OpenAI

eval_llm = OpenAI(model="gpt-3.5-turbo")

llm_selector = LLMSingleSelector.from_defaults(
    service_context=ServiceContext.from_defaults(llm=eval_llm)
)

In [110]:
llm_matches = run_evals(eval_dataset, llm_selector, choices, choice_to_id_dict)

  0%|          | 0/90 [00:00<?, ?it/s]

In [111]:
np.mean(llm_matches)

0.7111111111111111

## Plug into Router

We plug this into our `RouterQueryEngine` as an `EmbeddingSelector` (by default, an `LLMSingleSelector` is used in our router query engine).

In [None]:
from llama_index.query_engine import RouterQueryEngine
from llama_index import SummaryIndex, VectorStoreIndex

# define indexes/tools for wikipedia entries
tools = []
for idx, wiki_title in enumerate(wiki_titles):
    doc_id_vector = f"{wiki_title}_vector"
    doc_id_summary = f"{wiki_title}_summary"

    vector_index = VectorStoreIndex.from_documents(city_docs[wiki_title])
    summary_index = SummaryIndex.from_documents(city_docs[wiki_title])
    vector_tool = QueryEngineTool.from_defaults(
        query_engine=vector_index.as_query_engine(),
        description=city_descs[doc_id_vector]
    )
    summary_tool = QueryEngineTool.from_defaults(
        query_engine=summary_index.as_query_engine(),
        description=city_descs[doc_id_summary]
    )
    tools.extend([vector_tool, summary_tool])

In [None]:
router_query_engine = RouterQueryEngine.from_defaults(
    selector=ft_selector.from_defaults(),
    query_engine_tools=tools
)

In [None]:
response = router_query_engine.query("Tell me more about the sports teams in Toronto")

In [None]:
response = router_query_engine.query("Can you tell me more about Lisbon?")