
# Using Ragas for evaluating Retrieval Augmented Generation Chains on Databricks

Note: Validated with Databricks runtimes: 14.3 ML LTS and 13.3 ML LTS

## Introduction

Ragas is a framework for evaluating Retrieval Augmented Generation (RAG) pipelines. RAG denotes a class of LLM applications that use external data to augment the LLM’s context. Existing tools and frameworks help you build these pipelines, but evaluating them and quantifying their performance can be hard. This is where Ragas (RAG Assessment) comes in.

## Key Features

- **Dataset Generation**: Ragas allows for the creation of diverse and comprehensive datasets that can be used to test the performance of RAG models.
- **Customizable**: Users can tailor the datasets to include specific types of queries and documents, ensuring that the testing scenarios are relevant to their use cases.
- **Integration with Databricks**: Ragas runs in Databricks notebooks and can use Databricks model serving endpoints for the LLM and the embeddings, as this notebook does.

## Benefits

- **Improved Model Performance**: By using well-constructed testing datasets, users can identify weaknesses in their RAG models and make necessary adjustments to improve performance.
- **Efficiency**: Automating the dataset generation process saves time and resources, allowing data scientists to focus on model development and optimization.
- **Scalability**: Ragas can handle large volumes of data, making it suitable for enterprise-level applications.

In [0]:
%pip install langchain langchain_community ragas databricks-vectorsearch llama_index

In [0]:
%pip show langchain langchain_community databricks-vectorsearch

In [0]:
dbutils.library.restartPython()

In [0]:
# import subprocess
# subprocess.run(["python", "-c", "import dbutils; dbutils.library.restartPython()"])

In [0]:
from databricks.vector_search.client import VectorSearchClient

vsc = VectorSearchClient()

In [0]:
%run "../.././Tools/helpers"

Update the following variables with your resource names:
- vector_search_endpoint_name
- vector_index_name
- uc_save_path (catalog.schema)

In [0]:
# Replace each placeholder string with your own resource name
vector_search_endpoint_name = "<vector_search_endpoint_name>"
vector_index_name = "<vector_index_name>"
uc_save_path = "<catalog>.<schema>"
embeddings_model = "databricks-gte-large-en"
llm_model = "databricks-meta-llama-3-1-70b-instruct"

In [0]:
import mlflow

mlflow.autolog(disable=True)

In [0]:
def display_dataset(ds):
  ds_pd = ds.to_pandas()
  ds_df = spark.createDataFrame(ds_pd)
  display(ds_df)

# Get the documents using langchain

In [0]:
from pyspark.sql.types import StructType, StructField, StringType, LongType
from pyspark.sql.functions import col, monotonically_increasing_id
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader

schema = StructType([StructField('page_content', StringType(), True), StructField('type', StringType(), True), StructField('id', LongType(), True)])

loader = WebBaseLoader("https://mlflow.org/docs/latest/index.html")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(documents)

df = spark.createDataFrame(docs, schema).drop(col("metadata")).withColumn("id", monotonically_increasing_id())

display(df)

try:
  df.write.option("mergeSchema", "true").mode("overwrite").format("delta").saveAsTable(f"{uc_save_path}.raw_mlflow_docs")
  print(f"Successfully saved table: {uc_save_path}.raw_mlflow_docs!")
except Exception as e:
  # Report the underlying error instead of silently swallowing it
  print(f"Failed to write table: {uc_save_path}.raw_mlflow_docs. Error: {e}")

In [0]:

alter_table_cdf = f"ALTER TABLE {uc_save_path}.raw_mlflow_docs SET TBLPROPERTIES ('delta.enableChangeDataFeed' = 'true')"

# spark.sql raises on failure, so reaching the print means the property was set
spark.sql(alter_table_cdf)
print(f"Table [{uc_save_path}.raw_mlflow_docs] updated successfully!")

# Set up LLM configs 

In [0]:
from langchain_community.chat_models import ChatDatabricks
from langchain_core.messages import HumanMessage
from mlflow.deployments import get_deploy_client

generator_llm = ChatDatabricks(
    target_uri="databricks",
    endpoint= llm_model,
    temperature=0.1,
)

In [0]:
critic_llm = ChatDatabricks(
    target_uri="databricks",
    endpoint= llm_model,
    temperature=0.1,
)

In [0]:
from langchain_community.embeddings import DatabricksEmbeddings

# Reuse the embeddings endpoint name defined in the configuration cell
embeddings = DatabricksEmbeddings(endpoint=embeddings_model)

In [0]:
from ragas.testset.generator import TestsetGenerator

generator = TestsetGenerator.from_langchain(
    generator_llm,
    critic_llm,
    embeddings
)

# Make the index for testing

In [0]:
from databricks.vector_search.client import VectorSearchClient

vsc = VectorSearchClient()

In [0]:
create_vector_search_if_not_exists(vector_search_endpoint_name, vsc)

In [0]:
create_vector_index_if_not_exists(vector_search_endpoint_name, vector_index_name, uc_save_path, embeddings_model, vsc)
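The two helper functions called above come from the `Tools/helpers` notebook, which is not shown here. As a rough sketch of what such helpers might do, an idempotent setup could look like the following. The function names `ensure_endpoint` and `ensure_delta_sync_index`, the `id` primary key, the `TRIGGERED` pipeline type, and the assumed shape of the list responses are illustrative assumptions, not the actual helper code.

```python
def ensure_endpoint(client, endpoint_name):
    """Create the vector search endpoint unless one with this name exists.

    `client` is expected to behave like a VectorSearchClient.
    Returns True if an endpoint was created, False if it already existed.
    """
    existing = {e["name"] for e in client.list_endpoints().get("endpoints", [])}
    if endpoint_name in existing:
        return False
    client.create_endpoint(name=endpoint_name, endpoint_type="STANDARD")
    return True


def ensure_delta_sync_index(client, endpoint_name, index_name,
                            source_table, embedding_endpoint):
    """Create a Delta Sync index on `source_table` unless it already exists."""
    listed = client.list_indexes(name=endpoint_name).get("vector_indexes", [])
    if index_name in {i["name"] for i in listed}:
        return False
    client.create_delta_sync_index(
        endpoint_name=endpoint_name,
        index_name=index_name,
        source_table_name=source_table,       # e.g. f"{uc_save_path}.raw_mlflow_docs"
        pipeline_type="TRIGGERED",            # assumption: sync on demand
        primary_key="id",                     # assumption: the id column added earlier
        embedding_source_column="page_content",
        embedding_model_endpoint_name=embedding_endpoint,
    )
    return True
```

A Delta Sync index reads from the source table's change data feed, which is why the notebook enables `delta.enableChangeDataFeed` on `raw_mlflow_docs` before this step.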

In [0]:
vs_index = vsc.get_index(endpoint_name= vector_search_endpoint_name, index_name= vector_index_name)

vs_index.describe()

# Build the QA Chain

In [0]:
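# Import from langchain_community, consistent with the other imports in this notebook
from langchain_community.vectorstores import DatabricksVectorSearch
from langchain_community.embeddings import DatabricksEmbeddings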


def get_retriever(persist_dir: str = None):
    vs_index = vsc.get_index(
        endpoint_name= vector_search_endpoint_name,
        index_name= vector_index_name
    )
    # Create the retriever
    vectorstore = DatabricksVectorSearch(
        vs_index, text_column="page_content"
    )
    return vectorstore.as_retriever()

# Create the retriever used by the QA chain
retriever = get_retriever()

In [0]:
# If running a Databricks notebook attached to an interactive cluster in "single user"
# or "no isolation shared" mode, you only need to specify the endpoint name to create
# a `Databricks` instance to query a serving endpoint in the same workspace.

# Otherwise, you can manually specify the Databricks workspace hostname and personal access token
# or set `DATABRICKS_HOST` and `DATABRICKS_TOKEN` environment variables, respectively.
# You can set those environment variables based on the notebook context if run on Databricks

import os
from langchain_community.chat_models.databricks import ChatDatabricks
from langchain_core.messages import HumanMessage, SystemMessage

# Need this for job run: 
# os.environ['DATABRICKS_URL'] = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().getOrElse(None) 
# os.environ['DATABRICKS_TOKEN'] = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None)

def transform_input(**request):
  request["messages"] = [
    {
      "role": "user",
      "content": request["prompt"]
    }
  ]
  del request["prompt"]
  return request

# databricks-meta-llama-3-70b-instruct or databricks
llm = ChatDatabricks(endpoint=llm_model, transform_input_fn=transform_input, extra_params={"temperature": 0.1, "max_tokens":512})

# To generate answers faster, lower max_tokens above
prompt = "What is mlflow?"

output = llm.invoke(prompt)
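To make the role of `transform_input` in the cell above concrete, here is a small standalone run of the same function on a sample request. The keyword arguments passed in are illustrative; the point is only to show how a completion-style `prompt` is reshaped into a chat-style `messages` list.

```python
def transform_input(**request):
    # Wrap the plain prompt in a single user message and drop the original key
    request["messages"] = [{"role": "user", "content": request["prompt"]}]
    del request["prompt"]
    return request


payload = transform_input(prompt="What is mlflow?", temperature=0.1, max_tokens=512)
print(payload)
```

The remaining keys (`temperature`, `max_tokens`) pass through unchanged, and `prompt` no longer appears in the payload.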

In [0]:
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA

def build_qa_chain():
  
  template = """You are an expert in machine learning engineering and machine learning operations. 

  ### Instruction:
  Use only information in the following paragraphs to answer the question. Explain the answer with reference to these paragraphs. If you don't know, say that you do not know.

  {context}
  
  {question}

  ### Response:
  """
  prompt = PromptTemplate(input_variables=['context', 'question'], template=template)

  # Set "verbose" to True below to see the full prompt
  qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=True,
    chain_type_kwargs={
        "verbose": False,
        "prompt": prompt
    }
  )

  return qa_chain

In [0]:
qa_chain = build_qa_chain()

In [0]:
from mlflow.deployments import set_deployments_target

set_deployments_target("databricks")

In [0]:
question = "What are the different components of MLflow related to Model Management?"
result = qa_chain.invoke({"query": question})

In [0]:
# Check the result of the query
display(result["result"])

In [0]:
# Check the source documents returned with the answer
for item in result["source_documents"]:
  print(item.page_content)

# Create the dataset with Ragas

In [0]:
from ragas.testset.evolutions import simple, reasoning, multi_context

mlflow_dataset = generator.generate_with_langchain_docs(docs, test_size=12, distributions={simple: 0.5, reasoning: 0.25, multi_context: 0.25})
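As a quick arithmetic check on the call above, the distribution can be turned into target question counts per evolution type. This is a sketch with a hypothetical helper, `expected_counts`, and string keys standing in for the Ragas evolution objects. The counts are targets only; the generated test set may differ if some questions fail generation or filtering.

```python
def expected_counts(test_size, distribution):
    """Return the target number of questions per evolution type."""
    total = sum(distribution.values())
    if abs(total - 1.0) > 1e-9:
        raise ValueError(f"distribution shares must sum to 1, got {total}")
    return {name: round(test_size * share) for name, share in distribution.items()}


counts = expected_counts(12, {"simple": 0.5, "reasoning": 0.25, "multi_context": 0.25})
print(counts)
```

With `test_size=12`, this targets 6 simple, 3 reasoning, and 3 multi-context questions.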


In [0]:
display_dataset(mlflow_dataset)

In [0]:
mlflow_dataset_pd = mlflow_dataset.to_pandas()

# Create Evaluation Dataset

In [0]:
def get_responses(qa_chain, questions, ground_truth):
  from datasets import Dataset

  test_answers = []
  test_contexts = []

  for item in questions:
    temp = []
    result = qa_chain.invoke(item)
    test_answers.append(result["result"])
    for doc in result["source_documents"]:
      temp.append(doc.page_content)
    test_contexts.append(temp)


  dataset_dict = {
        "question": questions,
        "answer": test_answers,
        "contexts": test_contexts,
        "ground_truth": ground_truth,
  }

  result_ds = Dataset.from_dict(dataset_dict)

  return result_ds

In [0]:
result_ds = get_responses(qa_chain, questions=mlflow_dataset_pd['question'].values.tolist(), ground_truth=mlflow_dataset_pd['ground_truth'].values.tolist())

In [0]:
display_dataset(result_ds)


**Answer Relevancy**: 
- Answer relevancy measures how directly the generated answer addresses the question. Ragas uses the LLM to generate questions that the answer could be responding to, embeds them, and reports the mean cosine similarity between those generated questions and the original question. Incomplete or padded answers score lower. The metric does not judge factual accuracy.
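A minimal sketch of the similarity calculation behind an answer relevancy score, assuming the questions have already been generated from the answer and embedded. The function names and the two-dimensional toy vectors are illustrative and not part of the Ragas API.

```python
import math


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def answer_relevancy_score(question_embedding, generated_question_embeddings):
    """Mean cosine similarity between the original and the generated questions."""
    sims = [cosine_similarity(question_embedding, g)
            for g in generated_question_embeddings]
    return sum(sims) / len(sims)


original = [1.0, 0.0]
generated = [[1.0, 0.0], [0.0, 1.0]]  # one identical vector, one orthogonal vector
score = answer_relevancy_score(original, generated)
print(score)  # 0.5
```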

**Faithfulness**: 
- Faithfulness measures how well the generated answer is supported by the retrieved context. Ragas uses the LLM to break the answer into individual claims and checks whether each claim can be inferred from the context. The score is the fraction of supported claims, from 0 to 1. A low score suggests the model is adding information that the retrieved documents do not contain (hallucination).
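As a rough illustration of the ratio behind a faithfulness score, assume an LLM judge has already split the answer into claims and marked each one as supported or unsupported by the retrieved context. The function name and the verdict list are illustrative, not part of the Ragas API.

```python
def faithfulness_score(claim_verdicts):
    """Fraction of answer claims that the retrieved context supports."""
    if not claim_verdicts:
        raise ValueError("at least one claim is required")
    supported = sum(1 for verdict in claim_verdicts if verdict)
    return supported / len(claim_verdicts)


# Four claims extracted from an answer; the third is not backed by the context
verdicts = [True, True, False, True]
faithfulness = faithfulness_score(verdicts)
print(faithfulness)  # 0.75
```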

**Context Recall**: 
- Context recall measures how much of the ground truth answer is covered by the retrieved context. Ragas checks each sentence in the ground truth to see whether it can be attributed to the retrieved context. The score is the fraction of attributable sentences. A high score indicates that the retriever found the information needed to answer the question.
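A sentence-level sketch of a context recall calculation, assuming each ground truth sentence has already been judged as attributable (or not) to the retrieved context. The record layout and the placeholder sentences are hypothetical.

```python
def context_recall_score(ground_truth_sentences):
    """Fraction of ground truth sentences that the retrieved context covers."""
    total = len(ground_truth_sentences)
    if total == 0:
        raise ValueError("ground truth must contain at least one sentence")
    covered = [s for s in ground_truth_sentences if s["attributed"]]
    return len(covered) / total


judged = [
    {"sentence": "first ground truth sentence", "attributed": True},
    {"sentence": "second ground truth sentence", "attributed": False},
]
recall = context_recall_score(judged)
print(recall)  # 0.5
```

Unlike faithfulness, the denominator here comes from the ground truth rather than from the generated answer, so the score reflects the retriever rather than the generator.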

**Context Precision**: 
- Context precision measures whether the relevant chunks in the retrieved context are ranked ahead of irrelevant ones. Ragas uses the LLM to judge each retrieved chunk as relevant or not, given the question and the ground truth, then averages precision@k over the ranks that hold relevant chunks. A high score indicates that the retriever places useful chunks near the top.
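A small sketch of a rank-aware context precision calculation, assuming each retrieved chunk already has a binary relevance verdict in retrieval order. It averages precision@k over the ranks where a relevant chunk appears; the function name is illustrative.

```python
def context_precision_at_k(relevance):
    """relevance[i] is 1 if the chunk at rank i + 1 is relevant, else 0."""
    hits = 0
    weighted = 0.0
    for rank, is_relevant in enumerate(relevance, start=1):
        if is_relevant:
            hits += 1
            weighted += hits / rank  # precision@rank, counted only at relevant ranks
    return weighted / hits if hits else 0.0


# Same two relevant chunks, ranked differently
early = context_precision_at_k([1, 0, 1])  # (1/1 + 2/3) / 2
late = context_precision_at_k([0, 1, 1])   # (1/2 + 2/3) / 2
print(early, late)
```

The first ordering scores higher because a relevant chunk sits at rank 1, which is the behavior the metric is meant to reward.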

In [0]:
from datasets import Dataset
from ragas import evaluate
from ragas.metrics import (
    faithfulness,
    answer_relevancy,
    answer_correctness,
    context_precision, 
    context_recall,
)

metrics = [
    faithfulness,
    answer_relevancy,
    answer_correctness,
    context_precision, 
    context_recall,
]

In [0]:
from ragas import RunConfig

run_config = RunConfig()

run_config.timeout = 360
run_config.max_wait = 360  # the RunConfig field for the maximum retry wait is max_wait
run_config.thread_timeout = 360

In [0]:
run_config

In [0]:
from ragas import evaluate

result = evaluate(
    result_ds,
    metrics=metrics,
    llm = critic_llm, 
    embeddings = embeddings, 
    run_config = run_config
)


In [0]:
display_dataset(result)
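For a quick per-metric summary of the scores displayed above, a sketch like the following could be used. It assumes the per-question rows have been exported as a list of dictionaries (for example via `result.to_pandas().to_dict("records")`) and that a failed metric computation may appear as NaN. The helper name `mean_scores` and the toy rows are illustrative values, not real results.

```python
import math


def mean_scores(rows, metric_names):
    """Average each metric over the rows, skipping missing or NaN values."""
    summary = {}
    for name in metric_names:
        values = [
            row[name] for row in rows
            if row.get(name) is not None and not math.isnan(row[name])
        ]
        summary[name] = sum(values) / len(values) if values else float("nan")
    return summary


rows = [
    {"faithfulness": 1.0, "answer_relevancy": 0.5},
    {"faithfulness": 0.5, "answer_relevancy": float("nan")},
]
summary = mean_scores(rows, ["faithfulness", "answer_relevancy"])
print(summary)
```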
