# Knowledge Distillation For Fine-Tuning A GPT-3.5 Judge

Recent research has demonstrated GPT-4's ability to closely align with human judges when evaluating LLM-generated texts (e.g., see [[1]](https://arxiv.org/abs/2306.05685), [[2]](https://arxiv.org/abs/2303.16634)). In this notebook, we demonstrate how to use the `llama_index` library to distill knowledge from GPT-4 into GPT-3.5, so that the smaller GPT-3.5 judge moves closer to GPT-4 performance and, by proxy, closer to human judges.

To do so, we take the following steps:

1. Generate datasets: `train` and `test`
2. Perform knowledge distillation (using `train`)
3. Evaluate the distilled model on `test`

## 0 Prompt Templates & Asyncio Event Loop

In [None]:
PROMPTS = {
    "QUESTION_GEN": (
        "You are a Teacher/ Professor. Your task is to setup "
        "a quiz/examination. Using the provided context, formulate "
        "a single question that captures an important fact from the "
        "context. Restrict the question to the context information provided."
    )
}

In [None]:
import nest_asyncio

nest_asyncio.apply()

## 1 Generate datasets: `train` and `test`

We should not lose sight of the ultimate goal here, which is to build an LLM judge that closely matches human judges when evaluating LLM-generated texts. The work we need to do in this step, therefore, is to build a set of generated texts for our LLM judges to judge. More specifically, we will follow the "pairwise comparison" evaluation design pattern, where a pair of generated texts is passed to an LLM judge that is subsequently prompted to decide which of the two is better.

To generate a varied set of texts we'll use the following LLM text-generators:
1. HuggingFace: Vicuna-13B
2. HuggingFace: Mistral-7B
3. HuggingFace: Falcon-7B

The generation task we ask of each of these models is to generate an abstractive answer to a question when provided with relevant context (i.e., RAG).

### Using `DatasetGenerator` to build `train` and `test`

The specific procedure we will use here involves generating questions against a set of chunks of a given `Document`. With the `<question, chunk>` pairs in hand (which we can treat as a "simulated" retrieval), we pass this information to the three LLM generators and prompt each of them to generate an answer.
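As a minimal sketch of this step, the helper below holds a `<question, chunk>` pair and formats it into a prompt for an answer generator. `QuestionChunkPair`, `ANSWER_PROMPT`, and `build_answer_prompt` are hypothetical names introduced only for illustration, and the question and chunk strings are placeholders. In the notebook itself, a query engine takes care of prompt construction.

```python
from dataclasses import dataclass

# Hypothetical template for illustration; not the prompt llama_index uses internally.
ANSWER_PROMPT = (
    "Context information is below.\n"
    "---------------------\n"
    "{chunk}\n"
    "---------------------\n"
    "Using only the context above, answer the question.\n"
    "Question: {question}\n"
    "Answer: "
)


@dataclass(frozen=True)
class QuestionChunkPair:
    """A generated question together with the chunk it was generated from."""

    question: str
    chunk: str


def build_answer_prompt(pair: QuestionChunkPair) -> str:
    """Fill the template with the chunk (the 'simulated' retrieval) and the question."""
    return ANSWER_PROMPT.format(chunk=pair.chunk, question=pair.question)


# Placeholder values, not taken from the IPCC document.
pair = QuestionChunkPair(question="Placeholder question?", chunk="Placeholder chunk text.")
prompt = build_answer_prompt(pair)
print(prompt)
```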

Hang tight, we're almost there (sort of). Since we want to distill GPT-4 abilities for this task to GPT-3.5, we now need to generate GPT-4 judgements on the generated answers. To do this, we will pass the `<question, answer A, answer B>` (where `A` and `B` represent answers from any two of the LLM text-generators) as context to the GPT-4 judge and prompt it to decide the better answer of the two.
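To make the pairing concrete, here is a rough sketch of how the answer pairs for one question could be enumerated and turned into a judge prompt. The names `answer_pairs`, `build_judge_prompt`, and `JUDGE_PROMPT`, as well as the prompt wording, are assumptions made for illustration and are not taken from `llama_index`. The answers are placeholders.

```python
from itertools import combinations

# Hypothetical judge prompt; the wording is an assumption, not the notebook's prompt.
JUDGE_PROMPT = (
    "Question: {question}\n\n"
    "Answer A: {answer_a}\n\n"
    "Answer B: {answer_b}\n\n"
    "Which answer is better, A or B? Explain briefly."
)


def answer_pairs(answers: dict[str, str]) -> list[tuple[str, str]]:
    """Return every unordered pair of generator names that produced an answer."""
    return list(combinations(sorted(answers), 2))


def build_judge_prompt(question: str, answer_a: str, answer_b: str) -> str:
    """Format one <question, answer A, answer B> triple for the judge."""
    return JUDGE_PROMPT.format(
        question=question, answer_a=answer_a, answer_b=answer_b
    )


# Placeholder answers keyed by generator name.
answers = {
    "vicuna": "placeholder answer from Vicuna",
    "mistral": "placeholder answer from Mistral",
    "falcon": "placeholder answer from Falcon",
}
pairs = answer_pairs(answers)
first_a, first_b = pairs[0]
judge_prompt = build_judge_prompt(
    "Placeholder question?", answers[first_a], answers[first_b]
)
print(pairs)
```

With three generators, each question yields three pairs, so the size of `dataset` grows as three times the number of questions.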

With all of that, we can now build a `dataset` that looks like the one below.

| question | context, answer A, answer B | gpt-4-evaluation |
|----------|-----------------------------|------------------|
| ...      | ...                         | ...              |

And finally, to get `train` and `test` we will simply randomly shuffle `dataset` and split it using a 70/30 ratio. (Phew!)
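The shuffle-and-split step can be sketched with the standard library alone. The function name `train_test_split`, the fixed seed, and the placeholder rows are assumptions made for illustration.

```python
import random


def train_test_split(rows, train_fraction=0.7, seed=42):
    """Shuffle a copy of `rows` and split it into (train, test)."""
    shuffled = list(rows)  # copy so the caller's list is left untouched
    random.Random(seed).shuffle(shuffled)  # local RNG keeps the split reproducible
    cut = round(len(shuffled) * train_fraction)
    return shuffled[:cut], shuffled[cut:]


# Placeholder rows standing in for the real dataset records.
dataset = [{"question": f"q{i}"} for i in range(10)]
train, test = train_test_split(dataset)
print(len(train), len(test))
```

Using a dedicated `random.Random(seed)` instance rather than the module-level functions keeps the split independent of any other seeding done elsewhere in the notebook.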

With all that out of the way, let's spring into action. First, we will download the reference PDF document and create a set of questions against it.

In [None]:
# Download the PDF document (uncomment the line of code below)
# !curl https://www.ipcc.ch/report/ar6/wg2/downloads/report/IPCC_AR6_WGII_Chapter03.pdf --output IPCC_AR6_WGII_Chapter03.pdf



In [None]:
import random
from llama_index import SimpleDirectoryReader, ServiceContext

# load a document
documents = SimpleDirectoryReader(
    input_files=["IPCC_AR6_WGII_Chapter03.pdf"]
).load_data()

# Shuffle the documents
random.seed(42)
random.shuffle(documents)

In [None]:
# generate questions against chunks
from llama_index.evaluation import DatasetGenerator
from llama_index.llms import OpenAI

# set context for llm provider
gpt_35_context = ServiceContext.from_defaults(
    llm=OpenAI(model="gpt-3.5-turbo", temperature=0.3)
)

# instantiate a DatasetGenerator
dataset_generator = DatasetGenerator.from_documents(
    documents,
    question_gen_query=PROMPTS["QUESTION_GEN"],
    service_context=gpt_35_context,
)

In [None]:
# use DatasetGenerator to create questions from nodes
questions = dataset_generator.generate_questions_from_nodes(num=100)

# let's take a look at a few of these
for q in questions[:5]:
    print(q)

What are some approaches used to assess ecological responses to multiple climate-induced drivers?
Question: What is the projected decline in marine animal biomass with warming under SSP1-2.6 and SSP5-8.5 by 2080-2099 relative to 1995-2014?
Question: What is the projected impact of tropicalisation on species richness at local to regional scales?
What are the two Shared Socioeconomic Pathways (SSPs) under which the ensemble projections of global changes in phytoplankton phenology were made?
Question: According to the context information, what is the title of the paper published in 2017 that provides biogeochemical protocols and diagnostics for the CMIP6 Ocean Model Intercomparison Project (OMIP)?
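Some of the generated questions above carry a leading `Question:` prefix while others do not. If a uniform format is wanted before passing them to the answer generators, a small cleanup along the following lines might help. The helper `normalize_question` is a hypothetical name, not part of `llama_index`.

```python
def normalize_question(raw: str, prefix: str = "Question:") -> str:
    """Strip surrounding whitespace and a leading 'Question:' label, if present."""
    text = raw.strip()
    if text.lower().startswith(prefix.lower()):
        text = text[len(prefix):].strip()
    return text


# Two of the questions printed above, one with the prefix and one without.
raw_questions = [
    "Question: What is the projected impact of tropicalisation on species richness at local to regional scales?",
    "What are some approaches used to assess ecological responses to multiple climate-induced drivers?",
]
cleaned = [normalize_question(q) for q in raw_questions]
for q in cleaned:
    print(q)
```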


Now that we have the questions, the next step is to generate answers using the three LLM text-generators: Vicuna, Mistral, and Falcon.

In [None]:
# Create vector index
from llama_index import VectorStoreIndex
from llama_index.indices.vector_store.retrievers import VectorIndexRetriever

index = VectorStoreIndex.from_documents(documents=documents)

retriever = VectorIndexRetriever(  # what embeddings are being used?
    index=index,
    node_ids=list(index.index_struct.nodes_dict.values()),
    similarity_top_k=2,
)

In [None]:
from llama_index.query_engine.retriever_query_engine import (
    RetrieverQueryEngine,
)
from llama_index.llms import Replicate, OpenAI

# define our llm-generators (RAGs)

# Vicuna
vicuna_context = ServiceContext.from_defaults(
    llm=Replicate(
        model="replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b",
        temperature=0.3,
        context_window=2048,  # to use refine
    )
)

vicuna_query_engine = RetrieverQueryEngine.from_args(
    retriever=retriever, service_context=vicuna_context
)

# define our llm judges (also student/teacher models)

In [None]:
vicuna_context

ServiceContext(llm_predictor=LLMPredictor(system_prompt=None, query_wrapper_prompt=None), prompt_helper=PromptHelper(context_window=2048, num_output=256, chunk_overlap_ratio=0.1, chunk_size_limit=None, separator=' '), embed_model=OpenAIEmbedding(model_name='text-embedding-ada-002', embed_batch_size=10, callback_manager=<llama_index.callbacks.base.CallbackManager object at 0x16279af50>, deployment_name=None, additional_kwargs={}, api_key='<REDACTED>', api_type='open_ai', api_base='https://api.openai.com/v1', api_version=''), node_parser=SimpleNodeParser(text_splitter=SentenceSplitter(chunk_size=1024, chunk_overlap=20, separator=' ', paragraph_separator='\n\n\n', secondary_chunking_regex='[^,.;。？！]+[,.;。？！]?', chunking_tokenizer_fn=<function split_by_sentence_tokenizer.<locals>.split at 0x107e3d900>, callback_manager=<llama_index.callbacks.base.CallbackManager object at 0x16279af50>, tokenizer=functools.partial(<bound method Encoding.encode of <En

In [None]:
response = vicuna_query_engine.query(questions[1])

In [None]:
questions[1]

'Question: What is the projected decline in marine animal biomass with warming under SSP1-2.6 and SSP5-8.5 by 2080-2099 relative to 1995-2014?'

In [None]:
response

Response(response='\u200b', source_nodes=[NodeWithScore(node=TextNode(id_='d4e0bc38-a614-4463-a808-0a51923ed54d', embedding=None, metadata={'page_label': '446', 'file_name': 'IPCC_AR6_WGII_Chapter03.pdf'}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={<NodeRelationship.SOURCE: '1'>: RelatedNodeInfo(node_id='dd617b75-f8ef-478e-a4cc-e0d44fcbaaaa', node_type=<ObjectType.DOCUMENT: '4'>, metadata={'page_label': '446', 'file_name': 'IPCC_AR6_WGII_Chapter03.pdf'}, hash='7bab188793a5813eb45b7cd57170de4469734f90aa82b1530b37521ec2dfc3e9'), <NodeRelationship.PREVIOUS: '2'>: RelatedNodeInfo(node_id='f75150d0-c493-4ec6-8149-a4a04a7e75bd', node_type=<ObjectType.TEXT: '1'>, metadata={'page_label': '446', 'file_name': 'IPCC_AR6_WGII_Chapter03.pdf'}, hash='dba5548dd2fc08ba7df8984d813c79dbc367ce732d3e1819800b868a7f770cf3')}, hash='402ade5489b10a61bc5409d2842ea92d80f37853167966a320bdd16e334bf831', text='The new CMIP6 ESM ensemble projects \na decline in global zooplankton
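Note that the `response` field printed above contains only `'\u200b'`, a zero-width space, so this particular answer is effectively blank. When collecting answers from several generators, it may be worth filtering such cases out. The sketch below assumes only that each engine exposes `.query(question)` and returns an object with a `.response` string, as the output above suggests. `is_blank`, `collect_answers`, `StubEngine`, and `StubResponse` are hypothetical names, and the stubs stand in for real query engines.

```python
# Characters that render as nothing; treated as whitespace in this sketch.
ZERO_WIDTH = "\u200b\u200c\u200d\ufeff"


def is_blank(text) -> bool:
    """True if the text is None or holds only whitespace and zero-width characters."""
    if text is None:
        return True
    visible = "".join(ch for ch in text if ch not in ZERO_WIDTH)
    return visible.strip() == ""


def collect_answers(engines: dict, question: str) -> dict:
    """Query each engine and keep only the non-blank answers, keyed by engine name."""
    answers = {}
    for name, engine in engines.items():
        text = engine.query(question).response
        if not is_blank(text):
            answers[name] = text
    return answers


class StubResponse:
    """Stand-in for a query engine response object."""

    def __init__(self, response: str):
        self.response = response


class StubEngine:
    """Stand-in for a query engine; always returns the same canned text."""

    def __init__(self, canned: str):
        self.canned = canned

    def query(self, question: str) -> StubResponse:
        return StubResponse(self.canned)


engines = {
    "vicuna": StubEngine("\u200b"),  # mimics the blank response shown above
    "mistral": StubEngine("A placeholder answer."),
}
answers = collect_answers(engines, "Placeholder question?")
print(answers)
```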

In [None]:
# create our dataset, and split into train and test

## 2 Perform knowledge distillation

Okay, it's now time to distill some knowledge from GPT-4 to GPT-3.5. To do this, we will make use of the `OpenAIFinetuneEngine` class of `llama_index`.
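OpenAI's chat fine-tuning expects training data as JSON Lines, with one `{"messages": [...]}` object per line. As a rough sketch of what the distillation data could look like, the snippet below turns a judge prompt and the corresponding GPT-4 evaluation into one such record and writes it to a file. The helper names, the system prompt text, and the file name are assumptions made for illustration; the strings are placeholders, not real GPT-4 output.

```python
import json
import os
import tempfile

# Hypothetical system prompt; the wording is an assumption.
SYSTEM_PROMPT = "You are a judge that compares two answers to a question."


def to_finetune_record(judge_prompt: str, gpt4_evaluation: str) -> dict:
    """Build one chat-format training example: the student should imitate GPT-4."""
    return {
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": judge_prompt},
            {"role": "assistant", "content": gpt4_evaluation},
        ]
    }


def write_jsonl(records: list, path: str) -> None:
    """Write one JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")


# Placeholder strings standing in for a real prompt and a real GPT-4 judgement.
records = [
    to_finetune_record("placeholder judge prompt", "placeholder GPT-4 evaluation")
]
path = os.path.join(tempfile.mkdtemp(), "train.jsonl")
write_jsonl(records, path)
print(path)
```

Only rows from `train` should be written to this file, so that `test` remains unseen by the distilled model.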