In [8]:
%load_ext autoreload
%autoreload 2

# Systematically Improving RAG

We want to generate synthetic transactions for a given set of categories that we have. This will allow us to see how well an embedding based approach can retrieve the correct category for a given transaction. In a later notebook, we'll use this data to fine tune a re-ranker model that can do so

In [89]:
import json

categories = json.loads(open("./data/categories.json").read())
categories[:2]

[{'name': '6010-CLOUD-COMPUTING',
  'description': 'Cloud infrastructure costs including AWS, Azure, and Google Cloud Platform services'},
 {'name': '6015-CDN-SERVICES',
  'description': 'Content Delivery Network services for global content distribution'}]

In [90]:
departments = json.loads(open("./data/departments.json").read())
departments[:2]


['Sales', 'Marketing']

We want to generate some transactions for each category. We'll use the category description and code to generate a transaction. 


In [112]:
import instructor
from openai import AsyncOpenAI
from pydantic import BaseModel, Field
from typing import Optional
import random

client = instructor.from_openai(AsyncOpenAI())


class FakeTransaction(BaseModel):
    chain_of_thought: str
    merchant_name: str
    amount: float
    department: str = Field(description="Represents the department the transaction is associated with")
    location: str = Field(description="Transaction Charge Location")
    mccs: list[str] = Field(description="List of Merchant Categories that apply to the transaction (Eg. Restaurants, Office Supplies, Rideshare, etc). This should not be a number")
    card: str = Field(description="The spend program used to make the transaction")
    trip_name: Optional[str] = Field(description="Represents the name of the trip the transaction is associated with ( if on travel )")
    remarks:str = Field(description="Any additional information about the transaction that is not covered by the other fields")
    
async def generate_transaction(category,sem):


    async with sem:
        return (
            await client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "system",
                    "content": """
                    Generate transaction data that can belong to the given category of {{ category.name}} which has a description of {{ category.description }}.

                    This transaction should be realistic and should only be able to be filed under the given category.

                    Here are the categories that exist in your system ( including the given category):
                    { % for category in categories % }
                    - { category.name } : { category.description }
                    { % endfor % }
                    """,
                },{
                    "role": "user",
                    "content": """
                    Keep the following in mind:
                    - Merchant names should not reflect the category (Eg. Toyko Office Depot is a bad name, instead we could call it names like Popular, Takeshi and Sons, New Haven etc)
                    - The location should be a real city
                    - Merchant Category Codes should be a list of MCCs that apply to the transaction. This should not be a number but instead a category (Eg. Restaurants, Office Supplies, Rideshare, etc)
                    - The card field represents a spend program which the transaction would be associated with
                    - Only provide a trip name if the transaction is one that occurs on a business trip
                    """
                }
            ],
            context={"category": category, "categories": categories},
            response_model=FakeTransaction,
        ),
        category,
    )


Let's now generate 4 transactions for each category and see how well we can retrieve the correct category for a given transaction.

In [113]:
from tqdm.asyncio import tqdm_asyncio as asyncio
from asyncio import Semaphore

sem = Semaphore(10)

coros = []

questions = 30

for _ in range(questions):
        coros.append(generate_transaction(random.choice(categories),sem))

transactions = await asyncio.gather(*coros)

100%|██████████| 30/30 [00:10<00:00,  2.98it/s]


Now let's dump this into LanceDB and see how well we can retrieve the correct category for a given transaction.

In [119]:
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry

func = get_registry().get("openai").create(name="text-embedding-3-small")


class Category(LanceModel):
    name: str
    description: str
    text: str = func.SourceField()
    embedding: Vector(func.ndims()) = func.VectorField()


db = lancedb.connect("./lancedb")
table = db.create_table("categories", schema=Category, mode="overwrite")


table.add(
    [
        {
            "name": category["name"],
            "description": category["description"],
            "text": """
            name: {category["name"]}
            description: {category["description"]}
            """
        }
        for category in categories
    ]
)

table.create_fts_index(field_names=["text"], replace=True)

In [94]:
table.count_rows()

64

In [120]:
from braintrust import Eval, Score
import itertools


def calculate_mrr(predictions: list[str], gt: list[str]):
    mrr = 0
    for label in gt:
        if label in predictions:
            mrr = max(mrr, 1 / (predictions.index(label) + 1))
    return mrr


def get_recall(predictions: list[str], gt: list[str]):
    return len([label for label in gt if label in predictions]) / len(gt)


eval_metrics = [["mrr", calculate_mrr], ["recall", get_recall]]
sizes = [3, 5, 10, 15, 25]

metrics = {
    f"{metric_name}@{size}": lambda predictions, gt, m=metric_fn, s=size: (
        lambda p, g: m(p[:s], g)
    )(predictions, gt)
    for (metric_name, metric_fn), size in itertools.product(eval_metrics, sizes)
}


def evaluate_braintrust(input, output, **kwargs):
    return [
        Score(
            name=metric,
            score=score_fn(output, kwargs["expected"]),
            metadata={"query": input, "result": output, **kwargs["metadata"]},
        )
        for metric, score_fn in metrics.items()
    ]


def task(input):
    return [
        item["name"]
        for item in table.search(input, query_type="vector")
        .select(["name"])
        .limit(25)
        .to_list()
    ]


await Eval(
    "fine-tuning",  # Replace with your project name
    data=lambda: [
        {
            "input": f"""
            Name: {transaction.merchant_name}
            Category: {', '.join(transaction.mccs)}
            Department: {transaction.department}
            Amount: {transaction.amount}
            Location: {transaction.location}
            Card: {transaction.card}
            Trip Name: {transaction.trip_name if transaction.trip_name else "unknown"}
            
            Remarks: {transaction.remarks if transaction.remarks else "unknown"}
            """,
            "expected": [label["name"]],
        }
        for transaction,label in transactions
    ],  # Replace with your eval dataset
    task=task,  # Replace with your LLM call
    scores=[evaluate_braintrust],
)

Experiment fine-tuning-1730725020 is running at https://www.braintrust.dev/app/567/p/fine-tuning/experiments/fine-tuning-1730725020
fine-tuning (data): 30it [00:00, 16267.50it/s]


fine-tuning (tasks):   0%|          | 0/30 [00:00<?, ?it/s]


fine-tuning-1730725020 compared to fine-tuning-1730723348:
02.22% 'mrr@3'     score
02.89% 'mrr@5'     score
03.92% 'mrr@10'    score
03.92% 'mrr@15'    score
04.79% 'mrr@25'    score
06.67% 'recall@3'  score
10.00% 'recall@5'  score
16.67% 'recall@10' score
16.67% 'recall@15' score
33.33% 'recall@25' score

1.15s duration

See results for fine-tuning-1730725020 at https://www.braintrust.dev/app/567/p/fine-tuning/experiments/fine-tuning-1730725020


EvalResultWithSummary(summary="...", results=[...])

In [118]:
with open("./data/transactions.json", "w") as f:
    for transaction,label in transactions:
        f.write(json.dumps({
            "merchant_name": transaction.merchant_name,
            "amount": transaction.amount,
            "department": transaction.department,
            "location": transaction.location,
            "mccs": transaction.mccs,
            "card": transaction.card,
            "trip_name": transaction.trip_name if transaction.trip_name else '',
            "remarks": transaction.remarks if transaction.remarks else '',
            "label": label
        }) + "\n")
