In [None]:
import json
from enum import Enum

import yaml
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain.output_parsers import ResponseSchema as LCResponseSchema
from langchain_community.chat_models import ChatOllama
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI

In [None]:
# Secrets are optional: only ModelChoice.gpt4 uses OPENAI_API_KEY from them,
# so a missing or empty secrets file yields an empty mapping instead of an
# error (OPENAI_API_KEY is then None).
SECRETS_PATH = "../secrets.yaml"

try:
    with open(SECRETS_PATH, "r", encoding="utf-8") as file:
        # safe_load returns None for an empty document; normalize to {} so
        # SECRETS.get(...) keeps working downstream.
        SECRETS = yaml.safe_load(file) or {}
except FileNotFoundError:
    SECRETS = {}

## Define models

In [None]:
class ModelChoice(Enum):
    """Chat models that ``model_factory`` knows how to build.

    Each member's value is the model identifier string passed to the backend.
    """

    llama2 = "llama2"  # built with ChatOllama
    gpt4 = "gpt-4"  # built with ChatOpenAI; uses OPENAI_API_KEY


def supports_function_calling(model: BaseChatModel):
    """Return True if ``model`` has a ``bind_functions`` attribute, else False.

    This is a duck-typing heuristic: the attribute's presence is taken as
    support for function calling; it does not verify that calling it works.
    """
    missing = object()
    return getattr(model, "bind_functions", missing) is not missing


# Sampling temperature used for every model built by model_factory;
# 0 minimizes sampling randomness so categorizations are more repeatable.
TEMPERATURE = 0

# API key for the OpenAI backend; None when SECRETS has no such entry.
OPENAI_API_KEY = SECRETS.get("OPENAI_API_KEY", None)


def model_factory(model: ModelChoice, temperature=None) -> BaseChatModel:
    """Build the chat model selected by ``model``.

    Args:
        model: Which backend to instantiate.
        temperature: Optional sampling temperature. When None (the default),
            the module-level ``TEMPERATURE`` is read at call time, so editing
            that constant and re-running the call picks up the new value.

    Returns:
        A ``ChatOllama`` for ``ModelChoice.llama2`` or a ``ChatOpenAI`` for
        ``ModelChoice.gpt4``.

    Raises:
        ValueError: If ``model`` is not a supported ``ModelChoice`` member.
    """
    if temperature is None:
        temperature = TEMPERATURE

    # The enum value is the model id, so the identifier is defined only once.
    if model == ModelChoice.llama2:
        return ChatOllama(model=model.value, temperature=temperature)
    if model == ModelChoice.gpt4:
        return ChatOpenAI(
            model=model.value, temperature=temperature, openai_api_key=OPENAI_API_KEY
        )
    raise ValueError(f"Model {model} not supported")

## Output parsers

In [None]:
# The next cell rebinds the bare name `ResponseSchema` to a pydantic model.
# Use the aliased LangChain class so this cell stays correct when re-run
# after that one.
response_schemas = [
    LCResponseSchema(
        name="description",
        description="Original description of the transaction. Pass the input description here unchanged",
    ),
    LCResponseSchema(name="category", description="Parent category of the transaction"),
    LCResponseSchema(name="subcategory", description="Subcategory of the transaction"),
    LCResponseSchema(
        name="reasoning",
        description="Explanation of the transaction and why it's categorized as such",
    ),
]
# Supplies the prompt's format instructions and parses the model's reply into
# fields named after the schemas above.
structured_parser = StructuredOutputParser.from_response_schemas(response_schemas)

In [None]:
# NOTE(review): this class shadows `ResponseSchema` imported from
# langchain.output_parsers. After this cell runs, code that builds LangChain
# ResponseSchema(name=..., description=...) objects must not use the bare name.
class ResponseSchema(BaseModel):
    """Output schema for the model's transaction categorization.

    Pydantic (v1) counterpart of ``response_schemas``: same field names and
    descriptions.
    """

    description: str = Field(
        description="Original description of the transaction. Pass the input description here unchanged"
    )
    category: str = Field(description="Parent category of the transaction")
    subcategory: str = Field(description="Subcategory of the transaction")
    reasoning: str = Field(
        description="Explanation of the transaction and why it's categorized as such"
    )

## Define Prompt

In [None]:
# Categorization prompt. `description` is supplied per call; the parser's
# format instructions are bound once via partial_variables.
prompt = PromptTemplate(
    template="""
    You are a personal financial assistant that categorizes 
    bank transfers for transactions in Sweden. The categories
    should be useful for managing and visualizing personal budget 
    and expenses. Categorize the following transaction: "{description}".
    {format_instructions}
    """,
    input_variables=["description"],
    partial_variables={
        "format_instructions": structured_parser.get_format_instructions()
    },
)

## Compose and invoke the chain

In [None]:
# Choose the backend here. ModelChoice.gpt4 uses OPENAI_API_KEY from secrets.
MODEL_CHOICE = ModelChoice.llama2

model = model_factory(MODEL_CHOICE)

# prompt -> chat model -> parser: fill the template, query the model, then
# parse its reply according to response_schemas.
chain = prompt | model | structured_parser

In [None]:
description = "Netflix"
response = chain.invoke({"description": description})
# ensure_ascii=False prints Swedish characters (å, ä, ö) as-is instead of
# \u-escaping them.
print(json.dumps(response, indent=2, ensure_ascii=False))