# Prerequisite Installation

Install the libraries required for this tutorial as follows:

In [None]:
%pip install vinagent==0.0.4.post7 datasets==4.0.0

# Prepare Data and Tool

We will download an example legal case dataset from Hugging Face.

In [None]:
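from pathlib import Path
from datasets import load_dataset

# Make sure the output directory exists before writing the parquet file
Path("data").mkdir(parents=True, exist_ok=True)
dataset = load_dataset("joelniklaus/legal_case_document_summarization", split="test")
dataset.to_parquet("data/test.parquet")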

In [None]:
import pandas as pd
legal_case = pd.read_parquet("data/test.parquet")
legal_case.head()

To prepare a knowledge base of legal cases, we transform each row into a document whose content is the `judgement` text and whose metadata comprises `judgement_case`, `dataset_name`, and `summary`.

In [None]:
from langchain_core.documents import Document

docs = []
for i, row in legal_case.iterrows():
    # Use a separate name for the row so it is not shadowed by the Document
    doc = Document(
        page_content=row["judgement"],
        metadata={
            "judgement_case": i,
            "dataset_name": row["dataset_name"],
            "summary": row["summary"],
        },
    )
    docs.append(doc)

In [None]:
docs[:5]

We organize the vector database with the `VectorDatabaseFactory` class, which splits each document into chunks and saves them to the vector database.

In [None]:
from langchain_huggingface import HuggingFaceEmbeddings
from aucodb.vectordb.factory import VectorDatabaseFactory
from aucodb.vectordb.processor import DocumentProcessor
from langchain.text_splitter import RecursiveCharacterTextSplitter

# 1. Initialize embedding model
embedding_model = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")

# 2. Initialize document processor
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=200
)

doc_processor = DocumentProcessor(splitter=text_splitter)

# 3. Initialize vector database factory
db_type = "milvus"  # Supported types: ['chroma', 'faiss', 'milvus', 'pgvector', 'pinecone', 'qdrant', 'weaviate']
vectordb_factory = VectorDatabaseFactory(
    db_type=db_type,
    embedding_model=embedding_model,
    doc_processor=doc_processor
)

# 4. Store documents in the vector database
vectordb_factory.store_documents(docs)
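
The chunking step above can be pictured with a simplified sliding-window splitter. This is only a sketch of the idea behind `chunk_size=1000` and `chunk_overlap=200`: `RecursiveCharacterTextSplitter` additionally tries to break on separators such as blank lines, newlines, and spaces, so its real chunk boundaries will differ. The function name `sliding_window_chunks` is hypothetical and not part of any library used here.

```python
def sliding_window_chunks(text: str, chunk_size: int = 1000, chunk_overlap: int = 200) -> list[str]:
    """Split text into fixed-size character windows that overlap by chunk_overlap."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap  # how far each new window advances
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        # Stop once the current window has reached the end of the text
        if start + chunk_size >= len(text):
            break
    return chunks


sample = "x" * 2500
print([len(chunk) for chunk in sliding_window_chunks(sample)])
```

The overlap means that text near a chunk boundary appears in two neighbouring chunks, which reduces the chance that a relevant passage is cut in half.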

Let's test the vector database with a sample query that retrieves the top 5 most similar results.

In [None]:
query = "claimed exemption from Sales Tax under article 286"
top_k = 5
retrieved_docs = vectordb_factory.query(query, top_k)
for (i, doc) in enumerate(retrieved_docs):
    print(f"Document {i}: {doc}")

Next, we write `semantic_search_query`, which returns the most relevant chunks as `(id, distance)` pairs based on the semantic similarity of their embedding vectors.

In [None]:
from typing import Any, List, Tuple

def semantic_search_query(query: str, top_k: int) -> List[Tuple[Any, float]]:
    if vectordb_factory.vectordb.vector_store is None:
        raise ValueError("Vector store not initialized. Store documents first.")

    # Generate embedding for query
    query_vector = vectordb_factory.vectordb.embedding_model.embed_query(query)

    # Perform similarity search
    results = vectordb_factory.vectordb.client.search(
        collection_name=vectordb_factory.vectordb.collection_name,
        data=[query_vector],
        limit=top_k,
        output_fields=["text"],
        search_params={
            "metric_type": vectordb_factory.vectordb.metric_type
        },  # Use consistent metric type
    )[0]
    returned_docs = [(doc.id, doc.distance) for doc in results]
    return returned_docs

results = semantic_search_query(query=query, top_k=5)
results
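
Conceptually, a similarity search compares the query embedding with every stored chunk embedding and keeps the best matches. The sketch below illustrates this with cosine similarity on toy 3-dimensional vectors. It is an assumption-laden illustration, not what the vector database does internally: the real search runs inside the database with its own index and configured metric type, and the helper names here are hypothetical.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two vectors; 0.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k_by_cosine(query_vector: list[float], chunk_vectors: dict, top_k: int) -> list[tuple]:
    """Return the top_k (chunk_id, similarity) pairs, most similar first."""
    scored = [
        (chunk_id, cosine_similarity(query_vector, vector))
        for chunk_id, vector in chunk_vectors.items()
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]


# Toy embeddings standing in for real chunk vectors
toy_chunks = {
    0: [1.0, 0.0, 0.0],
    1: [0.9, 0.1, 0.0],
    2: [0.0, 1.0, 0.0],
}
print(top_k_by_cosine([1.0, 0.0, 0.0], toy_chunks, top_k=2))
```

The returned pairs have the same `(id, score)` shape as the output of `semantic_search_query`. Whether a larger or a smaller `distance` means "more similar" depends on the metric type in use (for example, cosine and inner product versus L2), so it is worth checking before combining it with other scores.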

Beyond semantic similarity, we also consider the percentage of words shared between the query and the document. This metric provides a second signal for ranking relevant documents, because the semantic similarity score is usually high for long sentences.
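
Written out, the word-overlap score computed in the next cell is as follows, where $V_q$ and $V_d$ denote the sets of distinct lowercased words in the query and in the document:

$$
\text{score}(q, d) = \frac{1}{2}\left(\frac{|V_q \cap V_d|}{|V_q|} + \frac{|V_q \cap V_d|}{|V_d|}\right)
$$

The score is defined as 0 when either set is empty. Since a full judgement usually contains far more distinct words than a short query, the second term tends to be close to zero, so in practice the score is roughly half the fraction of query words found in the document.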

In [None]:
def exact_match_score(query, doc):
    # Convert strings to sets of lowercased, whitespace-separated words (punctuation is kept)
    query_words = set(query.lower().split())
    doc_words = set(doc.lower().split())
    
    # Calculate intersection of words
    common_words = query_words.intersection(doc_words)
    
    # Avoid division by zero
    if len(query_words) == 0 or len(doc_words) == 0:
        return 0.0
        
    # Calculate score: 0.5 * (|V_q ∩ V_d|/|V_q| + |V_q ∩ V_d|/|V_d|)
    score = 0.5 * (len(common_words) / len(query_words) + len(common_words) / len(doc_words))
    
    return score

def exact_match_search_query(query, docs, top_k: int=5):
    # Calculate scores for all documents
    scores = [(id_doc, exact_match_score(query, doc.page_content)) for (id_doc, doc) in enumerate(docs)]
    
    # Sort by score in descending order
    sorted_scores = sorted(scores, key=lambda x: x[1], reverse=True)
    
    return sorted_scores[:min(top_k, len(docs))]

exact_match_search_query(query=query, docs=docs)
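
A small worked example makes the score easier to check by hand. The sketch below re-implements the same formula under the hypothetical name `overlap_score` so that it runs on its own, and also shows that `str.split()` leaves punctuation attached to words (so `"tax."` does not match `"tax"`). Stripping punctuation first is an optional variation, not something the tutorial's function does.

```python
import string


def overlap_score(query: str, doc: str) -> float:
    """0.5 * (|Vq ∩ Vd| / |Vq| + |Vq ∩ Vd| / |Vd|) over lowercased word sets."""
    query_words = set(query.lower().split())
    doc_words = set(doc.lower().split())
    if not query_words or not doc_words:
        return 0.0
    common = query_words & doc_words
    return 0.5 * (len(common) / len(query_words) + len(common) / len(doc_words))


def strip_punctuation(text: str) -> str:
    """Remove ASCII punctuation characters from the text."""
    return text.translate(str.maketrans("", "", string.punctuation))


toy_query = "exemption from sales tax"
toy_doc = "The company claimed exemption from Sales Tax."

# "tax." keeps its trailing period, so only 3 of the 4 query words match:
# 0.5 * (3/4 + 3/7)
print(overlap_score(toy_query, toy_doc))

# After stripping punctuation all 4 query words match: 0.5 * (4/4 + 4/7)
print(overlap_score(toy_query, strip_punctuation(toy_doc)))
```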

Next, let's create the class `SearchLegalEngine`, which includes the following functionality:

- `_create_legal_cases_data`: Creates the legal cases dataset. Each document is one legal record.
- `_initialize_document_processor`: Creates a vector database factory, which initializes the vector database and stores the list of documents.
- `exact_match_search_query`: Computes a score based on the percentage of words that overlap between the query and each document.
- `semantic_search_query`: Returns a list of scores based on semantic similarity.
- `query_fusion_score`: Combines the exact-match and semantic scores.

In [None]:
import pandas as pd
from datasets import load_dataset
from pathlib import Path
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from typing import Literal
from typing import Any, Dict, List


class SearchLegalEngine:
    def __init__(self, 
            top_k: int=5, 
            temp_data_path: Path = Path("data/test.parquet"),
            db_type: Literal['chroma', 'faiss', 'milvus', 'pgvector', 'pinecone', 'qdrant', 'weaviate'] = "milvus",
            embedding_model: str="BAAI/bge-small-en-v1.5"
        ):
        self.top_k = top_k
        self.embedding_model = HuggingFaceEmbeddings(model_name=embedding_model)
        self.temp_data_path = temp_data_path
        self.db_type = db_type
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200
        )
        self.doc_processor = DocumentProcessor(splitter=self.text_splitter)

    
    def _download_legal_case(self):
        dataset = load_dataset("joelniklaus/legal_case_document_summarization", split="test")
        dataset.to_parquet(self.temp_data_path)
        

    def _create_legal_cases_data(self):
        self._download_legal_case()
        legal_case = pd.read_parquet(self.temp_data_path)

        self.docs = []
        for (i, doc) in legal_case.iterrows():
            doc = Document(
                page_content=doc['judgement'], 
                metadata={
                    "judgement_case": i, 
                    "dataset_name": doc["dataset_name"],
                    "summary": doc["summary"]
                })
            self.docs.append(doc)
        return self.docs

    def _initialize_document_processor(self):
        self.vectordb_factory = VectorDatabaseFactory(
            db_type=self.db_type,
            embedding_model=self.embedding_model,
            doc_processor=self.doc_processor
        )
        self._create_legal_cases_data()        
        self.vectordb_factory.store_documents(self.docs)

    def exact_match_score(self, query, doc):
        # Convert strings to sets of lowercased, whitespace-separated words (punctuation is kept)
        query_words = set(query.lower().split())
        doc_words = set(doc.lower().split())
        
        # Calculate intersection of words
        common_words = query_words.intersection(doc_words)
        
        # Avoid division by zero
        if len(query_words) == 0 or len(doc_words) == 0:
            return 0.0
            
        # Calculate score: 0.5 * (|V_q ∩ V_d|/|V_q| + |V_q ∩ V_d|/|V_d|)
        score = 0.5 * (len(common_words) / len(query_words) + len(common_words) / len(doc_words))
        
        return score

    def exact_match_search_query(self, query, docs):
        
        # Calculate scores for all documents
        scores = [(id_doc, self.exact_match_score(query, doc.page_content)) for (id_doc, doc) in enumerate(docs)]
        
        return scores
 

    def semantic_search_query(self, query: str, top_k: int = None) -> List[tuple]:
        actual_top_k = top_k or self.top_k
        if self.vectordb_factory.vectordb.vector_store is None:
            raise ValueError("Vector store not initialized. Store documents first.")

        # Generate embedding for query
        query_vector = self.vectordb_factory.vectordb.embedding_model.embed_query(query)

        # Perform similarity search
        results = self.vectordb_factory.vectordb.client.search(
            collection_name=self.vectordb_factory.vectordb.collection_name,
            data=[query_vector],
            limit=actual_top_k,
            output_fields=["text"],
            search_params={
                "metric_type": self.vectordb_factory.vectordb.metric_type
            },  # Use consistent metric type
        )[0]
        returned_docs = [(doc.id, doc.distance) for doc in results]
        returned_docs = sorted([doc for doc in returned_docs], key=lambda x: x[0], reverse=False)
        return returned_docs
    
    def query_fusion_score(self, query: str, top_k: int=None, threshold: float=None, w_semantic: float=0.5):
        """Query a list of documents based on exact matching and semantic scores. Return a list of similar documents.
        Args:
            query (str): The query to search for.
            top_k (int): The number of documents to return. Defaults to self.top_k.
            threshold (float): The minimum fusion score to return. Defaults to None.
            w_semantic (float): The weight of the semantic score. Defaults to 0.5.
        Returns:
            list: A list of similar documents.
        """
        exact_match_scores = self.exact_match_search_query(query=query, docs=self.docs)
        semantic_scores = self.semantic_search_query(query=query, top_k=len(self.docs))
        scores = [
            (
                id_exac, 
                { 
                    "semantic_score": seman_score,
                    "exac_score": exac_score,
                    "fusion_score":(1-w_semantic)*exac_score + w_semantic*seman_score 
                }
            )
            for ((id_exac, exac_score), (id_seman, seman_score)) 
                in list(zip(exact_match_scores, semantic_scores))
        ]
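        actual_top_k = top_k or self.top_k  # fall back to the default when top_k is None
        sorted_scores = sorted(scores, key=lambda x: x[1]["fusion_score"], reverse=True)[:min(actual_top_k, len(self.docs))]
        sorted_docs = [(self.docs[i], score) for (i, score) in sorted_scores]
        if threshold is not None: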
            filter_docs = [doc for (doc, score) in sorted_docs if score['fusion_score'] > threshold]
            return filter_docs
        else:
            return sorted_docs
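
Note that `query_fusion_score` pairs the two score lists by position with `zip`, which implicitly assumes that both lists have the same length and are ordered by the same ids. One possible way to make the pairing explicit is to key both lists by id before combining them. The sketch below is a standalone illustration under stated assumptions: the function name `fuse_scores` is hypothetical, ids missing from one list default to a score of 0.0, and both scores are assumed to be "higher is better" on comparable scales.

```python
def fuse_scores(exact_scores, semantic_scores, w_semantic=0.5, top_k=5, threshold=None):
    """Combine (id, score) lists by id using a weighted sum, best fusion score first."""
    exact_by_id = dict(exact_scores)
    semantic_by_id = dict(semantic_scores)

    fused = []
    for doc_id in sorted(set(exact_by_id) | set(semantic_by_id)):
        exact = exact_by_id.get(doc_id, 0.0)        # assumption: missing id scores 0.0
        semantic = semantic_by_id.get(doc_id, 0.0)  # assumption: missing id scores 0.0
        fusion = (1 - w_semantic) * exact + w_semantic * semantic
        fused.append(
            (doc_id, {"semantic_score": semantic, "exac_score": exact, "fusion_score": fusion})
        )

    fused.sort(key=lambda item: item[1]["fusion_score"], reverse=True)
    fused = fused[:top_k]
    if threshold is not None:
        fused = [item for item in fused if item[1]["fusion_score"] > threshold]
    return fused


toy_exact = [(0, 0.2), (1, 0.5), (2, 0.1)]
toy_semantic = [(2, 0.9), (0, 0.4)]  # different order, and id 1 is missing
for doc_id, score in fuse_scores(toy_exact, toy_semantic, w_semantic=0.7, top_k=2):
    print(doc_id, round(score["fusion_score"], 3))
```

With this layout, a reordered or shorter semantic result list no longer shifts scores onto the wrong document.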

Test `query_fusion_score`, which fuses the exact-match and semantic scores, to find a list of legal cases related to `Sales Tax`.

In [None]:
search_legal_engine = SearchLegalEngine(
    top_k=5, 
    temp_data_path=Path("data/test.parquet"),
    db_type="milvus",
    embedding_model="BAAI/bge-small-en-v1.5"
)

search_legal_engine._initialize_document_processor()

In [None]:
query = "claimed exemption from Sales Tax"
search_legal_engine.query_fusion_score(query, top_k=5, w_semantic=0.7)

Now we write the class into the module file `vinagent/tools/legal_assistant/search_legal_cases.py`, which will be loaded as a search tool in the next section.

In [None]:
%%writefile /Users/phamdinhkhanh/Documents/Courses/Manus/vinagent/vinagent/tools/legal_assistant/search_legal_cases.py
import pandas as pd
from datasets import load_dataset
from pathlib import Path
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from aucodb.vectordb.factory import VectorDatabaseFactory
from aucodb.vectordb.processor import DocumentProcessor
from typing import Literal, Any, Dict, List
from vinagent.register import primary_function


class SearchLegalEngine:
    def __init__(self, 
            top_k: int=5, 
            temp_data_path: Path = Path("data/test.parquet"),
            db_type: Literal['chroma', 'faiss', 'milvus', 'pgvector', 'pinecone', 'qdrant', 'weaviate'] = "milvus",
            embedding_model: str="BAAI/bge-small-en-v1.5"
        ):
        self.top_k = top_k
        self.embedding_model = HuggingFaceEmbeddings(model_name=embedding_model)
        self.temp_data_path = temp_data_path
        self.db_type = db_type
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200
        )
        self.doc_processor = DocumentProcessor(splitter=self.text_splitter)

    
    def _download_legal_case(self):
        dataset = load_dataset("joelniklaus/legal_case_document_summarization", split="test")
        dataset.to_parquet(self.temp_data_path)
        

    def _create_legal_cases_data(self):
        self._download_legal_case()
        legal_case = pd.read_parquet(self.temp_data_path)

        self.docs = []
        for (i, doc) in legal_case.iterrows():
            doc = Document(
                page_content=doc['judgement'], 
                metadata={
                    "judgement_case": i, 
                    "dataset_name": doc["dataset_name"],
                    "summary": doc["summary"]
                })
            self.docs.append(doc)
        return self.docs


    def _initialize_document_processor(self):
        self.vectordb_factory = VectorDatabaseFactory(
            db_type=self.db_type,
            embedding_model=self.embedding_model,
            doc_processor=self.doc_processor
        )
        self._create_legal_cases_data()        
        self.vectordb_factory.store_documents(self.docs)

    def exact_match_score(self, query, doc):
        # Convert strings to sets of lowercased, whitespace-separated words (punctuation is kept)
        query_words = set(query.lower().split())
        doc_words = set(doc.lower().split())
        
        # Calculate intersection of words
        common_words = query_words.intersection(doc_words)
        
        # Avoid division by zero
        if len(query_words) == 0 or len(doc_words) == 0:
            return 0.0
            
        # Calculate score: 0.5 * (|V_q ∩ V_d|/|V_q| + |V_q ∩ V_d|/|V_d|)
        score = 0.5 * (len(common_words) / len(query_words) + len(common_words) / len(doc_words))
        
        return score

    def exact_match_search_query(self, query, docs):
        # Calculate scores for all documents
        scores = [(id_doc, self.exact_match_score(query, doc.page_content)) for (id_doc, doc) in enumerate(docs)]
        return scores
 

    def semantic_search_query(self, query: str, top_k: int = None) -> List[tuple]:
        actual_top_k = top_k or self.top_k
        if self.vectordb_factory.vectordb.vector_store is None:
            raise ValueError("Vector store not initialized. Store documents first.")

        # Generate embedding for query
        query_vector = self.vectordb_factory.vectordb.embedding_model.embed_query(query)

        # Perform similarity search
        results = self.vectordb_factory.vectordb.client.search(
            collection_name=self.vectordb_factory.vectordb.collection_name,
            data=[query_vector],
            limit=actual_top_k,
            output_fields=["text"],
            search_params={
                "metric_type": self.vectordb_factory.vectordb.metric_type
            },  # Use consistent metric type
        )[0]
        returned_docs = [(doc.id, doc.distance) for doc in results]
        returned_docs = sorted([doc for doc in returned_docs], key=lambda x: x[0], reverse=False)
        return returned_docs

    def query_fusion_score(self, query: str, top_k: int=None, threshold: float=None, w_semantic: float=0.5):
        """Query a list of documents based on exact matching and semantic scores. Return a list of similar documents.
        Args:
            query (str): The query to search for.
            top_k (int): The number of documents to return. Defaults to self.top_k.
            threshold (float): The minimum fusion score to return. Defaults to None.
            w_semantic (float): The weight of the semantic score. Defaults to 0.5.
        Returns:
            list: A list of similar documents.
        """
        exact_match_scores = self.exact_match_search_query(query=query, docs=self.docs)
        semantic_scores = self.semantic_search_query(query=query, top_k=len(self.docs))
        scores = [
            (
                id_exac, 
                { 
                    "semantic_score": seman_score,
                    "exac_score": exac_score,
                    "fusion_score":(1-w_semantic)*exac_score + w_semantic*seman_score 
                }
            )
            for ((id_exac, exac_score), (id_seman, seman_score)) 
                in list(zip(exact_match_scores, semantic_scores))
        ]
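        actual_top_k = top_k or self.top_k  # fall back to the default when top_k is None
        sorted_scores = sorted(scores, key=lambda x: x[1]["fusion_score"], reverse=True)[:min(actual_top_k, len(self.docs))]
        sorted_docs = [(self.docs[i], score) for (i, score) in sorted_scores]
        if threshold is not None: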
            filter_docs = [doc for (doc, score) in sorted_docs if score['fusion_score'] > threshold]
            return filter_docs
        else:
            return sorted_docs

@primary_function
def query_similar_legal_cases(query: str, n_legal_cases: int=2, threshold: float=0.6):
    """Query the similar legal cases to the given query.
    Args:
        query (str): The query string.
        n_legal_cases (int): The number of legal cases to return. Defaults to 2.
        threshold (float): The similarity threshold. Defaults to 0.6.
    
    Returns:
        The similar legal cases.
    """
    search_legal_engine = SearchLegalEngine(
        top_k=n_legal_cases, 
        temp_data_path=Path("data/test.parquet"),
        db_type="milvus",
        embedding_model="BAAI/bge-small-en-v1.5"
    )
    # _initialize_document_processor() already calls _create_legal_cases_data()
    search_legal_engine._initialize_document_processor()
    docs = search_legal_engine.query_fusion_score(query, top_k=n_legal_cases, threshold=threshold, w_semantic=0.7)
    return docs
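
As written, `query_similar_legal_cases` builds a new engine, and therefore re-indexes the dataset, on every call. If that becomes too slow, one option is to build the engine once and reuse it. The sketch below shows the idea with `functools.lru_cache`; `FakeEngine`, `get_engine`, and `query_similar` are hypothetical stand-ins used only so the example runs without the real dependencies.

```python
from functools import lru_cache


class FakeEngine:
    """Hypothetical stand-in for SearchLegalEngine that counts how often it is built."""

    builds = 0

    def __init__(self, top_k: int):
        FakeEngine.builds += 1  # the real engine would download and index data here
        self.top_k = top_k

    def query_fusion_score(self, query: str, top_k: int = None):
        return [f"result for {query!r}"][: top_k or self.top_k]


@lru_cache(maxsize=1)
def get_engine() -> FakeEngine:
    """Build the engine on first use and return the cached instance afterwards."""
    return FakeEngine(top_k=5)


def query_similar(query: str, n_legal_cases: int = 2):
    return get_engine().query_fusion_score(query, top_k=n_legal_cases)


query_similar("sales tax exemption")
query_similar("land acquisition")
print(FakeEngine.builds)  # the engine is constructed only once
```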

# Initialize Legal Agent

We demonstrate how to initialize a legal assistant with vinagent that can help users with tasks such as:

- Searching for relevant legal cases.
- Summarizing the major timeline of events in a legal case.
- Analyzing arguments to identify the strengths and weaknesses of the appellants' arguments.
- Performing jurisdictional analysis of the Act and Regulation.
- Analyzing ethics and bias in court rulings.

In [None]:
from langchain_openai import ChatOpenAI
from vinagent.agent.agent import Agent
from dotenv import load_dotenv, find_dotenv

# Load API keys (e.g. OPENAI_API_KEY) from a local .env file
load_dotenv(find_dotenv('.env'))

llm = ChatOpenAI(
    model = "o4-mini"
)

legal_agent = Agent(
    name="Legal Assistant",
    description="A legal assistant who can find the similar legal cases",
    llm = llm,
    skills=[
        "search similar legal cases",
        "summarize the legal cases",
        "extract the main entities in the legal cases",
        "search information on the internet"
    ],
    tools=[
        '/Users/phamdinhkhanh/Documents/Courses/Manus/vinagent/vinagent/tools/legal_assistant/search_legal_cases.py',
        '/Users/phamdinhkhanh/Documents/Courses/Manus/vinagent/vinagent/tools/websearch_tools.py'
    ]
)

# Find the relevant legal case

Attorneys usually search for relevant legal cases when preparing a lawsuit. The primary goal of finding similar legal cases is to identify relevant precedents that guide the resolution of a current case, ensuring consistency and fairness in legal outcomes. By researching cases with comparable facts or legal issues, attorneys can build stronger arguments, predict judicial rulings, and uncover defenses or counterarguments. This process supports compliance with the principle of stare decisis, enhances case strategy, and provides leverage in negotiations, ultimately saving time and resources while grounding legal decisions in established judicial authority.

In [None]:
message = legal_agent.invoke(
    "Let find one legal case claimed exemption sales tax", 
    is_tool_formatted=False,
    max_history=1
)
message

This is the main content of the similar legal case.

In [None]:
message.artifact

Here, `is_tool_formatted=False` disables the follow-up step in which the agent rewrites the tool message. We set `max_history=1` so that only the current query is used and the earlier history is dropped, which keeps the context within the maximum length the LLM accepts.

In [None]:
legal_agent.in_conversation_history.get_history()

The history only returns a list of messages ending with a `ToolMessage`. If you want the agent to rewrite the `ToolMessage` to suit human preference, set `is_tool_formatted=True`.

In [None]:
message = legal_agent.invoke(
    "Let find one legal case claimed exemption sales tax", 
    is_tool_formatted=True,
    max_history=1
)
message

In [None]:
legal_agent.in_conversation_history.get_history()

By default, a Vinagent agent stores up to the last 10 messages in its conversation history. Therefore, if we continue querying, the new answers are appended to the existing history. In this case, because you allowed the agent to modify the tool result, the last message is an `AIMessage`.
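
A bounded history of this kind can be pictured with a fixed-length queue. The sketch below uses `collections.deque` purely as an illustration of the "keep the last 10 messages" behaviour and of taking only the most recent message as context. It is not Vinagent's actual implementation, and the message strings are made up.

```python
from collections import deque

# A queue that silently drops the oldest item once it holds 10 messages
history = deque(maxlen=10)
for turn in range(1, 13):
    history.append(f"message {turn}")

print(len(history))   # number of messages retained
print(history[0])     # oldest message still kept
print(history[-1])    # newest message

# Analogue of max_history=1: pass only the most recent message as context
context = list(history)[-1:]
print(context)
```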

# Summarize legal case

With a very long legal case, we cannot follow every event in detail. Therefore, we summarize the case in a short form so that attorneys can read it faster.

In [None]:
legal_case = docs[199].page_content
message = legal_agent.invoke(f"Let's summarize this legal case in 200 words including context, development, plaintiff's arguments, and court ruling \n{legal_case}", max_history=1)

In [None]:
print(message.content)
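
Because the whole judgement is interpolated into the prompt, a very long case can approach the model's context limit. One simple safeguard is to cap the case text before building the prompt. The sketch below is a minimal illustration: `build_case_prompt` and its `max_chars` default are hypothetical, and a character budget is only a rough proxy for the model's token limit.

```python
def build_case_prompt(instruction: str, case_text: str, max_chars: int = 20000) -> str:
    """Join an instruction and a case text, truncating the case text to max_chars."""
    if len(case_text) > max_chars:
        # Mark the cut so the model knows the text is incomplete
        case_text = case_text[:max_chars] + "\n[... truncated ...]"
    return f"{instruction}\n{case_text}"


toy_case = "The appellant claimed exemption from sales tax. " * 3
prompt = build_case_prompt("Let's summarize this legal case in 200 words", toy_case, max_chars=60)
print(prompt)
```

The resulting string could then be passed to `legal_agent.invoke(..., max_history=1)` in place of the inline f-string.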

# Timeline and Fact Organization

We can structure the timeline of events for each legal case in descending order, which makes the flow of events easier to track.

In [None]:
message = legal_agent.invoke(f"Let's create a timeline of events in this legal case in descending order: \n{legal_case}", max_history=1)

In [None]:
print(message.content)

# Argument analysis

Sometimes, attorneys dig deep to understand the strengths and weaknesses of the appellants' arguments, so that they can improve their chances of winning before the trial begins. `legal_agent` can also analyze these strengths and weaknesses in depth.

In [None]:
message = legal_agent.invoke(f"Let's analyze the strengths and weaknesses of appellants' arguments: \n{legal_case}", max_history=1)

In [None]:
print(message.content)

# Jurisdictional Analysis

Jurisdictional analysis is vital in legal proceedings to ensure challenges are pursued correctly and efficiently. It serves to identify the correct legal framework, ensure compliance with time limits, define the scope of review, clarify court hierarchy and appeal routes, guide remedies and outcomes, and align with statutory interplay. By addressing these aspects, jurisdictional analysis prevents procedural errors, focuses arguments on permissible legal grounds, and informs strategic decisions, thereby upholding the integrity of the judicial process.

In [None]:
message = legal_agent.invoke(f"Let's analyze the jurisdictional analysis: \n{legal_case}", max_history=1)
print(message.content)

# Ethics and Bias in Court Rulings

Court rulings need to be ethical and free from bias to deliver fair, open, and responsible decisions, especially in difficult cases, given that many unfair judgements have been made. It’s about ensuring justice for future generations, weighing economic gains against environmental and local community impacts, being transparent, handling scientific unknowns carefully, avoiding institutional blind spots, and striking the right balance in judicial oversight. If courts don’t tackle these ethical and bias issues head-on, they risk deepening inequalities, weakening environmental protections, and losing the public’s trust in the system.

In [None]:
message = legal_agent.invoke(f"Let's consider the ethical and bias arguments of court ruling for this legal case: \n{legal_case}", max_history=1)
print(message.content)