# Fine-Tuning a Reranker using Cohere

### This demo will show you how to:
1. Generate synthetic data using DSPy
2. Export all your data from your Weaviate instance
3. Fine-tune a reranker using Cohere
4. Query in Weaviate using your fine-tuned reranker model

#### Note: To fine-tune a model in Cohere, you need to have a minimum of 256 unique queries with at least 1 relevant passage per query. If you already have a dataset with query + relevant passages, you can skip to the end of the notebook!
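
If you already have a dataset, a quick sanity check against the requirements in the note can save a failed upload. The sketch below is a minimal example, not an official Cohere validator: the helper name `check_rerank_dataset` is hypothetical, and the field names `query` and `relevant_passages` are the ones this notebook writes to its JSONL files later on.

```python
from typing import Dict, List

MIN_UNIQUE_QUERIES = 256  # threshold quoted in the note above


def check_rerank_dataset(records: List[Dict]) -> List[str]:
    """Return a list of problems; an empty list means the data meets the note's requirements."""
    problems = []
    unique_queries = set()
    for i, record in enumerate(records):
        query = (record.get("query") or "").strip()
        # Keep only non-empty passages
        passages = [p for p in record.get("relevant_passages", []) if p and p.strip()]
        if not query:
            problems.append(f"record {i}: empty query")
            continue
        if not passages:
            problems.append(f"record {i}: no relevant passage")
        unique_queries.add(query)
    if len(unique_queries) < MIN_UNIQUE_QUERIES:
        problems.append(
            f"only {len(unique_queries)} unique queries, need at least {MIN_UNIQUE_QUERIES}"
        )
    return problems
```

Calling `check_rerank_dataset(records)` on the parsed lines of a JSONL file returns human-readable messages that you can print before uploading.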

## Connect to Weaviate Instance

In [26]:
import weaviate
import json
import os 

client = weaviate.Client(
    url = "WEAVIATE_URL",  # Replace with your cluster url
    auth_client_secret=weaviate.AuthApiKey(api_key="AUTH_KEY"),  # Replace w/ your Weaviate instance API key
    additional_headers = {
        "X-Cohere-Api-Key": "API_KEY" # Replace with your inference API key
    }
)



## Import Libraries

In [2]:
import logging
import sys

import nest_asyncio

nest_asyncio.apply()

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

In [3]:
import random
import uuid
import tqdm
from typing import List, Any
import pandas as pd
import matplotlib.pyplot as plt
import asyncio
import weaviate

INFO:numexpr.utils:Note: NumExpr detected 10 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
Note: NumExpr detected 10 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.
NumExpr defaulting to 8 threads.


## Load in Data

In [20]:
import re

def chunk_list(lst, chunk_size):
    """Break a list into chunks of the specified size."""
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]

def split_into_sentences(text):
    """Split text into sentences using regular expressions."""
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
    return [sentence.strip() for sentence in sentences if sentence.strip()]

def read_and_chunk_index_files(main_folder_path):
    """Read index.mdx files from subfolders, split into sentences, and chunk every 5 sentences."""
    blog_chunks = []
    for folder_name in os.listdir(main_folder_path):
        subfolder_path = os.path.join(main_folder_path, folder_name)
        if os.path.isdir(subfolder_path):
            index_file_path = os.path.join(subfolder_path, 'index.mdx')
            if os.path.isfile(index_file_path):
                with open(index_file_path, 'r', encoding='utf-8') as file:
                    content = file.read()
                    sentences = split_into_sentences(content)
                    sentence_chunks = chunk_list(sentences, 5)
                    sentence_chunks = [' '.join(chunk) for chunk in sentence_chunks]
                    blog_chunks.extend(sentence_chunks)
    return blog_chunks

# Example usage
main_folder_path = './blog'
blogs = read_and_chunk_index_files(main_folder_path)

In [6]:
print(len(blogs))

1182
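
To see what a chunk looks like, here is a toy run of the same two helpers on a made-up string. The sample text is an assumption for illustration only; the functions are copied from the cell above so the snippet runs on its own.

```python
import re


def chunk_list(lst, chunk_size):
    """Break a list into chunks of the specified size."""
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def split_into_sentences(text):
    """Split text into sentences using regular expressions."""
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
    return [sentence.strip() for sentence in sentences if sentence.strip()]


# Made-up sample text with six sentences
sample = (
    "Weaviate is a vector database. It stores objects and vectors. "
    "Can it rerank results? Yes. Rerankers reorder candidates. "
    "This is the sixth sentence."
)

sentences = split_into_sentences(sample)
chunks = [" ".join(chunk) for chunk in chunk_list(sentences, 5)]

for number, chunk in enumerate(chunks, start=1):
    print(number, chunk)
```

With a chunk size of 5, the six sentences produce two chunks: the first holds five sentences and the second holds the remaining one. The last chunk of every blog post can therefore be much shorter than the others.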


## Define Schema

If you need to reset your schema and delete objects in a collection, run:
`client.schema.delete_all()` or `client.schema.delete_class("Blogs")`

In [28]:
schema = {
   "classes": [
       {
           "class": "Blogs",
           "description": "Weaviate blogs",
           "vectorizer": "text2vec-cohere",
           "properties": [
               {
                   "name": "content",
                   "dataType": ["text"],
                   "description": "Content from the blogs.",
               },
               {
                   "name": "synthetic_query",
                   "dataType": ["text"],
                   "description": "Synthetic query generated from a LM."
               }
           ]
       }      
   ]
}
    
client.schema.create(schema)

## Import 

In [None]:
%pip install dspy-ai > /dev/null

#### To generate the synthetic queries, we will use DSPy's signature and chain-of-thought module. 

In [7]:
import dspy
import openai
openai.api_key = "sk-key"

class WriteQuery(dspy.Signature):
    """Write a query that this document would have the answer to."""

    document = dspy.InputField(desc="A document containing information.") 
    query = dspy.OutputField(desc="A short question uniquely answered by the document.")

gpt4T = dspy.OpenAI(model='gpt-4-1106-preview', max_tokens=1000, model_type='chat')

for idx, blog_chunk in enumerate(blogs[300:]):
    if idx >= 400: # only need 400
        break
    print(idx)
    with dspy.context(lm=gpt4T):
        llm_query = dspy.ChainOfThought(WriteQuery)(document=blog_chunk)
    print(llm_query)
    data_properties = {
        "content": blog_chunk,
        "synthetic_query": llm_query.query
    }
    print(f"{data_properties}\n")
    client.data_object.create(data_properties, "Blogs")

0
Prediction(
    rationale="Document: A document containing information.\nReasoning: Let's think step by step in order to understand the relationship between recall, heap usage, and parameter sets as depicted in the charts. The document seems to describe the trade-offs between memory usage and the accuracy of a search algorithm, as influenced by the choice of parameter sets. The charts on the left illustrate how recall (the ability to retrieve relevant items) is affected by the amount of heap memory used. The charts on the right seem to compare heap usage with different parameter sets, suggesting that certain sets require more memory but may result in a larger and more accurate graph. The document also hints at the possibility of adjusting the level of data compression, which would further affect these relationships. Therefore, the query should aim to extract information about how these variables interact.",
    query='How does heap usage affect recall and the accuracy of search resul

KeyboardInterrupt: 

#### Here is one example of the chain-of-thought module in DSPy. It takes the initial signature (prompt) and inserts the first blog chunk into it.

In [35]:
with dspy.context(lm=gpt4T):
    dspy.ChainOfThought(WriteQuery)(document=blogs[0]).query
gpt4T.inspect_history(n=1)





Write a query that this document would have the answer to.

---

Follow the following format.

Document: A document containing information.
Reasoning: Let's think step by step in order to ${produce the query}. We ...
Query: A short question uniquely answered by the document.

---

Document: --- title: Combining LangChain and Weaviate slug: combining-langchain-and-weaviate authors: [erika] date: 2023-02-21 tags: ['integrations'] image: ./img/hero.png description: "LangChain is one of the most exciting new tools in AI. It helps overcome many limitations of LLMs, such as hallucination and limited input lengths." --- ![Combining LangChain and Weaviate](./img/hero.png) Large Language Models (LLMs) have revolutionized the way we interact and communicate with computers. These machines can understand and generate human-like language on a massive scale. LLMs are a versatile tool that is seen in many applications like chatbots, content creation, and much more. Despite being a powerful tool, 

## Export the data from your Weaviate instance

To fine-tune the model, we need to export our data from Weaviate and upload it to Cohere.
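
Each exported line pairs one query with a list of relevant passages. The sketch below shows that mapping in isolation, assuming objects shaped like the `Blogs` collection defined earlier (`content` and `synthetic_query`). The helper names `to_training_record` and `write_jsonl` are hypothetical and are not part of the Weaviate or Cohere clients.

```python
import json
from typing import Dict, Iterable, Optional


def to_training_record(obj: Dict) -> Optional[Dict]:
    """Map one Weaviate object to a {query, relevant_passages} record, or None if unusable."""
    query = (obj.get("synthetic_query") or "").strip()
    content = (obj.get("content") or "").strip()
    # Same filter as the export cell: drop queries shorter than 5 characters
    if len(query) < 5 or not content:
        return None
    return {"query": query, "relevant_passages": [content]}


def write_jsonl(objects: Iterable[Dict], path: str) -> int:
    """Write one JSON record per line and return how many records were written."""
    written = 0
    with open(path, "w", encoding="utf-8") as jsonl_file:
        for obj in objects:
            record = to_training_record(obj)
            if record is None:
                continue
            jsonl_file.write(json.dumps(record) + "\n")
            written += 1
    return written
```

Keeping the mapping in one small function makes it easy to reuse for both the training and the validation file.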

In [36]:
'''
This example will show you how to get all of your data
out of Weaviate and into a JSON file using the Cursor API!
'''
import json
import time
start = time.time()

# Step 1 - Get the UUID of the first object inserted into Weaviate

get_first_object_weaviate_query = """
{
  Get {
    Blogs {
      _additional {
        id
      }
    }
  }
}
"""

results = client.query.raw(get_first_object_weaviate_query)
uuid_cursor = results["data"]["Get"]["Blogs"][0]["_additional"]["id"]

# Step 2 - Get the Total Objects in Weaviate

total_objs_query = """
{
    Aggregate {
        Blogs {
            meta {
                count
            }
        }
    }
}
"""

results = client.query.raw(total_objs_query)
total_objects = results["data"]["Aggregate"]["Blogs"][0]["meta"]["count"]

# Step 3 - Iterate through Weaviate with the Cursor
# Note: `after` is exclusive, so the object whose UUID seeds the cursor is not returned.
increment = 50
data = []
for i in range(0, total_objects, increment):
    results = (
        client.query.get("Blogs", ["content", "synthetic_query"])
        .with_additional(["id"])
        .with_limit(increment)
        .with_after(uuid_cursor)
        .do()
    )["data"]["Get"]["Blogs"]
    if not results:
        break  # nothing left after the cursor
    # extract data from result into JSON
    for result in results:
        # skip objects whose synthetic query is missing or too short to be useful
        if len(result["synthetic_query"] or "") < 5:
            continue
        data.append({
            "query": result["synthetic_query"],
            "relevant_passages": [result["content"]],
        })
    # update uuid cursor to continue the loop:
    # after the inner loop, result holds the last object of this batch
    uuid_cursor = result["_additional"]["id"]

# save JSON
file_path = "my_data.jsonl"
with open(file_path, 'w') as jsonl_file:
    for item in data:
        jsonl_file.write(json.dumps(item) + '\n')

print("Your data is out of Weaviate!")
print(f"Extracted {total_objects} in {time.time() - start} seconds.")

Your data is out of Weaviate!
Extracted 302 in 0.6464710235595703 seconds.


In [10]:
'''
Grab objects 303 to 405 for the validation set
'''
import weaviate
import json
import time
start = time.time()

# Step 1 - Get the UUID of the first object inserted into Weaviate

get_first_object_weaviate_query = """
{
  Get {
    Blogs {
      _additional {
        id
      }
    }
  }
}
"""

results = client.query.raw(get_first_object_weaviate_query)
uuid_cursor = results["data"]["Get"]["Blogs"][0]["_additional"]["id"]

# Step 2 - Get the Total Objects in Weaviate

total_objs_query = """
{
    Aggregate {
        Blogs {
            meta {
                count
            }
        }
    }
}
"""

results = client.query.raw(total_objs_query)
total_objects = results["data"]["Aggregate"]["Blogs"][0]["meta"]["count"]

# Step 3 - Iterate through Weaviate with the Cursor
start_object_index = 302  # objects 1..302 were already exported for training
end_object_index = 405    # last object to keep (inclusive)
increment = 50
data = []

# The cursor starts at object 1 and `after` is exclusive,
# so the first object returned below is object 2.
position = 1

while position < end_object_index:
    results = (
        client.query.get("Blogs", ["content", "synthetic_query"])
        .with_additional(["id"])
        .with_limit(increment)
        .with_after(uuid_cursor)
        .do()
    )["data"]["Get"]["Blogs"]
    if not results:
        break  # reached the end of the collection
    for result in results:
        position += 1
        uuid_cursor = result["_additional"]["id"]
        # Skip everything outside the validation range (objects 303 to 405)
        if position <= start_object_index or position > end_object_index:
            continue
        if len(result["synthetic_query"] or "") < 5:
            continue
        data.append({
            "query": result["synthetic_query"],
            "relevant_passages": [result["content"]],
        })

# save JSON
file_path = "validation.jsonl"
with open(file_path, 'w') as jsonl_file:
    for item in data:
        jsonl_file.write(json.dumps(item) + '\n')

print("Your data is out of Weaviate!")
print(f"Extracted objects from 303 to 405 in {time.time() - start} seconds.")

Your data is out of Weaviate!
Extracted objects from 303 to 405 in 0.21405792236328125 seconds.
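
Because both files are produced by walking the same collection, it is worth confirming that no query ends up in both of them before uploading. The sketch below is a small check under that assumption; `read_queries` and `overlapping_queries` are hypothetical helper names, and the file names are the ones written by the two cells above.

```python
import json
from typing import Set


def read_queries(path: str) -> Set[str]:
    """Collect the `query` field of every non-empty line in a JSONL file."""
    queries = set()
    with open(path, encoding="utf-8") as jsonl_file:
        for line in jsonl_file:
            line = line.strip()
            if line:
                queries.add(json.loads(line)["query"])
    return queries


def overlapping_queries(train_path: str, validation_path: str) -> Set[str]:
    """Return the queries that appear in both the training and the validation file."""
    return read_queries(train_path) & read_queries(validation_path)
```

For example, `overlapping_queries("my_data.jsonl", "validation.jsonl")` should return an empty set. Any query it does return is evaluated on data the model was trained on, which inflates the validation metrics.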


## Re-Index Data 

To use the fine-tuned reranker, we need to upload our data again to a new collection whose reranker configuration points to the fine-tuned `model_id`.

### New collection with the same data

In [16]:
schema = {
   "classes": [
       {
           "class": "BlogsFineTuned",
           "description": "Weaviate blogs",
           "vectorizer": "text2vec-cohere",
           "moduleConfig": {
                "reranker-cohere": {
                    "model": "model_id" # grab the model_id from Cohere
                }
           },
           "properties": [
               {
                   "name": "content",
                   "dataType": ["text"],
                   "description": "Content from the blogs.",
               }
           ]
       }      
   ]
}
    
client.schema.create(schema)

In [None]:
schema = {
    "class": "BlogsFineTuned",
    "description": "Weaviate blogs",
    "vectorizer": "text2vec-cohere",
    "moduleConfig": {
        "reranker-cohere": {
            "model": "reranker model path"
        }
    }
}

### Upload data (same as above)

In [27]:
for blog in blogs:
    data_properties = {
        "content": blog
    }
    client.data_object.create(
        data_object = data_properties,
        class_name = "BlogsFineTuned"
    )
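
## Query with the fine-tuned reranker

With the data in `BlogsFineTuned`, a query can ask Weaviate to rerank the vector search results. The sketch below assumes the v3 Python client used throughout this notebook, a cluster with the `reranker-cohere` module enabled, and the `X-Cohere-Api-Key` header set when connecting. `rerank_clause` and `search_with_reranker` are hypothetical helper names, and the search text you pass in is up to you.

```python
def rerank_clause(prop: str, query: str) -> str:
    """Build the `_additional` rerank fragment for a GraphQL Get query."""
    # Escape backslashes and double quotes so the query stays valid GraphQL
    escaped = query.replace("\\", "\\\\").replace('"', '\\"')
    return f'rerank(property: "{prop}" query: "{escaped}") {{ score }}'


def search_with_reranker(client, query: str, limit: int = 10):
    """Vector search on BlogsFineTuned, then rerank the hits on the `content` property."""
    response = (
        client.query.get("BlogsFineTuned", ["content"])
        .with_near_text({"concepts": [query]})
        .with_additional(rerank_clause("content", query))
        .with_limit(limit)
        .do()
    )
    return response["data"]["Get"]["BlogsFineTuned"]
```

Calling `search_with_reranker(client, "How does heap usage affect recall?")` should return the objects together with their rerank score under `_additional`. The collection's `moduleConfig` decides which model does the scoring, so no model name is needed in the query itself.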